import numpy as np
from scipy.spatial.transform import Rotation as R

from env import setup_environment, shutdown_environment
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions

# --- skills (already provided – DO NOT redefine) -----------------
from skill_code import rotate, move, pick, pull, place
# -----------------------------------------------------------------


# -----------------------------------------------------------------
# Helper utilities
# -----------------------------------------------------------------
def _safe_get(dct, *keys):
    """
    Look up several candidate keys and return the first usable value.

    Keys are tried in the order given; a key counts only if it is present
    in ``dct`` and its value is not ``None``.

    Raises:
        KeyError: when no candidate key yields a non-None value.
    """
    # None is a safe "nothing found" marker because None values are
    # filtered out by the generator's own condition.
    match = next(
        (dct[name] for name in keys if name in dct and dct[name] is not None),
        None,
    )
    if match is None:
        raise KeyError(f"None of the keys {keys} exist in the given dictionary.")
    return match


# -----------------------------------------------------------------
# Main execution function (oracle plan)
# -----------------------------------------------------------------
def run_combined_task():
    """
    Run the scripted plan: open the bottom drawer, then bin the rubbish.

    Sequence (the skill calls themselves live in ``_execute_plan``):
        1) rotate the gripper 90° about Z
        2) move to the bottom drawer's side position
        3) move to the bottom drawer's anchor position
        4) pick at the anchor position (grasp the drawer handle)
        5) pull 0.15 along the x axis (open the drawer)
        6) move to the rubbish and pick it
        7) move to the bin and place the rubbish

    The environment is shut down on every exit path.  The closing banner is
    printed whenever the plan ends without an exception, including when the
    environment reports ``done`` part-way through.

    Raises:
        KeyError: if none of the alias keys for a required position is
            present (with a non-None value) in the positions dictionary.
        ValueError: if the rubbish and the bin resolve to the very same
            entry of the positions dictionary.
        Exception: anything raised by the environment or a skill is
            printed and re-raised unchanged.
    """
    print("===== Starting Combined Task =====")

    # --------------------------------------------------------------
    # 1.  Environment initialisation
    # --------------------------------------------------------------
    env, task = setup_environment()
    try:
        # Reset task; the language descriptions are not used by this plan.
        _descriptions, obs = task.reset()

        # Optional – video capture for debugging / demonstration
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # ----------------------------------------------------------
        # 2.  Retrieve relevant object positions
        # ----------------------------------------------------------
        positions = get_object_positions()

        # Drawer (bottom) – position aliases
        bottom_side = _safe_get(
            positions,
            "bottom_side_pos", "side_pos_bottom", "side-pos-bottom",
        )
        bottom_anchor = _safe_get(
            positions,
            "bottom_anchor_pos", "anchor_pos_bottom", "anchor-pos-bottom",
        )

        # Rubbish lying inside drawer or on table – fall-backs
        rubbish_pos = _safe_get(
            positions,
            "item3", "item2", "item1", "rubbish", "trash", "garbage",
        )

        # Bin position where rubbish must be dropped
        bin_pos = _safe_get(
            positions,
            "bin", "bin_pos", "trash_bin", "trash",
        )

        # "trash" is a fall-back alias in BOTH lookups above, so a dictionary
        # that only offers that key hands back the same entry twice.  Picking
        # and placing at one location cannot complete the task, so fail
        # before the robot moves.
        if rubbish_pos is bin_pos:
            raise ValueError(
                "Rubbish and bin positions resolved to the same entry of the "
                "positions dictionary; cannot tell the rubbish from the bin."
            )

        # ----------------------------------------------------------
        # 3.  Execute oracle plan (7 steps from the Specification)
        # ----------------------------------------------------------
        _execute_plan(env, task, bottom_side, bottom_anchor, rubbish_pos, bin_pos)

    except Exception as exc:
        print("!! Exception during task execution:", repr(exc))
        raise

    finally:
        shutdown_environment(env)

    print("===== End of Combined Task =====")


def _execute_plan(env, task, bottom_side, bottom_anchor, rubbish_pos, bin_pos):
    """Run the seven plan steps in order, returning early if a skill reports ``done``."""
    # Step-1 target: 90° about Z as a quaternion (scipy default order: x, y, z, w).
    target_quat = R.from_euler('z', 90, degrees=True).as_quat()

    # Every skill call before the final place, as
    # (banner or None, deferred call, message printed if the call reports done).
    # The calls are wrapped in lambdas so each skill runs only when its turn comes.
    leading_steps = (
        (
            "\n--- Step-1  rotate  ---",
            lambda: rotate(env, task, target_quat),
            "[Task] Terminated right after rotate.",
        ),
        (
            "\n--- Step-2  move-to-side ---",
            lambda: move(env, task, np.asarray(bottom_side)),
            "[Task] Terminated after moving to side position.",
        ),
        (
            "\n--- Step-3  move-to-anchor ---",
            lambda: move(env, task, np.asarray(bottom_anchor)),
            "[Task] Terminated after moving to anchor position.",
        ),
        (
            "\n--- Step-4  pick-drawer ---",
            lambda: pick(env, task, np.asarray(bottom_anchor)),
            "[Task] Terminated after picking drawer handle.",
        ),
        (
            "\n--- Step-5  pull (open drawer) ---",
            lambda: pull(env, task, pull_distance=0.15, pull_axis='x'),
            "[Task] Terminated after pulling drawer.",
        ),
        (
            "\n--- Step-6  pick (rubbish) ---",
            lambda: move(env, task, np.asarray(rubbish_pos)),
            "[Task] Terminated while moving towards rubbish.",
        ),
        (
            None,  # second half of Step-6, no extra banner
            lambda: pick(env, task, np.asarray(rubbish_pos)),
            "[Task] Terminated after picking rubbish.",
        ),
        (
            "\n--- Step-7  place (into bin) ---",
            lambda: move(env, task, np.asarray(bin_pos)),
            "[Task] Terminated while moving to bin.",
        ),
    )

    for banner, action, stop_message in leading_steps:
        if banner is not None:
            print(banner)
        _, _, done = action()
        if done:
            print(stop_message)
            return

    # Final action of Step-7; its result decides the closing status line.
    _, reward, done = place(env, task, np.asarray(bin_pos))

    # ----------------------------------------------------------
    # 4.  Final status
    # ----------------------------------------------------------
    if done:
        print("[Task] Task completed successfully! Reward:", reward)
    else:
        # Some environments leave `done` False even on success.
        print("[Task] Plan executed. Environment reports done = False.")


# -----------------------------------------------------------------
# Entry-point guard
# -----------------------------------------------------------------
if __name__ == "__main__":
    # Run the plan only when this file is executed as a script, so that
    # importing the module does not start the environment.
    run_combined_task()