import numpy as np
from math import pi

from env import setup_environment, shutdown_environment
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions

# Only use the predefined, already-implemented skills.
from skill_code import rotate, move, pick, pull, place


# --------------------------------------------------------------------------- #
#  Helper utilities                                                           #
# --------------------------------------------------------------------------- #
def _safe_pos(pos_dict, key, fallback=None):
    """
    Look up one object position in the dict returned by get_object_positions().

    Parameters
    ----------
    pos_dict : dict(str → (x, y, z) | None)
        Dictionary returned by get_object_positions().
    key : str
        Key to look up in `pos_dict`.
    fallback : iterable[float] | None
        Used when `key` is missing from `pos_dict` OR maps to None.

    Returns
    -------
    np.ndarray
        A new float32 numpy array holding the position (expected to be
        (x, y, z); the length is not validated). The result is always a copy,
        so callers may modify it in place without affecting `pos_dict` or
        `fallback`.

    Raises
    ------
    KeyError
        If `key` is missing (or maps to None) and `fallback` is None.
    """
    value = pos_dict.get(key)
    if value is None:
        value = fallback
    if value is None:
        raise KeyError(f"[run_task] Missing object position for key: “{key}”")
    # np.array (unlike np.asarray) always copies. An already-float32 ndarray
    # stored in the dict is therefore never aliased by the returned value.
    return np.array(value, dtype=np.float32)


# --------------------------------------------------------------------------- #
#  Main task logic                                                            #
# --------------------------------------------------------------------------- #
def run_task_open_drawer_and_dispose():
    """
    Execute the oracle plan that opens the bottom drawer and disposes of trash.

    Steps:
        1) Rotate the gripper to 90° about the Z-axis.
        2) Move to the bottom drawer's side position.
        3) Move to the drawer's anchor/handle position.
        4) Grasp the handle (pick-drawer).
        5) Pull the drawer open.
        6) Move to a hover pose above the trash.
        7) Pick the piece of trash.
        8) Move to a hover pose above the bin.
        9) Place the trash into the bin.

    If any skill before the final step reports ``done=True``, the plan stops
    early and returns None. The environment is always shut down once it has
    been set up, even if a step raises.

    Raises
    ------
    KeyError
        If a required object position is missing from get_object_positions()
        (raised by _safe_pos).
    """
    print("\n================  Start Task: Open Drawer & Dispose  ================\n")

    # ---------------------------------------------------------------------- #
    #  Environment initialisation                                            #
    # ---------------------------------------------------------------------- #
    env, task = setup_environment()
    try:
        # RLBench (or similar) task reset
        _, obs = task.reset()

        # Optional video capture: wrap step/get_observation so frames are recorded.
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # ------------------------------------------------------------------ #
        #  Query object positions                                            #
        # ------------------------------------------------------------------ #
        positions = get_object_positions()

        # Drawer related
        bottom_side_pos   = _safe_pos(positions, 'bottom_side_pos')
        bottom_anchor_pos = _safe_pos(positions, 'bottom_anchor_pos')

        # Bin
        bin_pos = _safe_pos(positions, 'bin')

        # Trash: prefer "rubbish", falling back to the generic "item3" when
        # "rubbish" is absent or None. Compare against None explicitly: the
        # truth value of a numpy position array is ambiguous and raises
        # ValueError. This matches how _safe_pos treats missing values.
        trash_key = 'rubbish' if positions.get('rubbish') is not None else 'item3'
        trash_pos = _safe_pos(positions, trash_key)

        # Hover offset (+0.15 along Z) used to approach objects from above.
        hover_offset = np.array([0.0, 0.0, 0.15], dtype=np.float32)

        # ------------------------------------------------------------------ #
        #  Execute the oracle plan                                           #
        # ------------------------------------------------------------------ #
        print("=====================  Oracle Plan  =====================\n")

        # STEP-1  rotate(gripper, zero_deg, ninety_deg)
        # Quaternion for +90° about Z with the scalar part last.
        # NOTE(review): assumes rotate() expects (x, y, z, w) order — confirm.
        ninety_deg_quat = np.array([0.0, 0.0, np.sin(pi / 4.0), np.cos(pi / 4.0)],
                                   dtype=np.float32)
        obs, reward, done = rotate(env, task, ninety_deg_quat)
        if done:
            print("[Task] Terminated during rotation.")
            return

        # STEP-2  move-to-side (bottom drawer)
        obs, reward, done = move(env, task, bottom_side_pos)
        if done:
            print("[Task] Terminated during move-to-side.")
            return

        # STEP-3  move-to-anchor (drawer handle)
        obs, reward, done = move(env, task, bottom_anchor_pos)
        if done:
            print("[Task] Terminated during move-to-anchor.")
            return

        # STEP-4  pick-drawer (generic pick, very small approach distance)
        obs, reward, done = pick(
            env, task,
            target_pos=bottom_anchor_pos,
            approach_distance=0.05,
            approach_axis='y'          # approach along local +Y (handle)
        )
        if done:
            print("[Task] Terminated during pick-drawer.")
            return

        # STEP-5  pull drawer out
        obs, reward, done = pull(
            env, task,
            pull_distance=0.18,        # empirically generous distance
            pull_axis='x'
        )
        if done:
            print("[Task] Terminated during pull.")
            return

        # STEP-6  move above the trash (hover position to avoid collision)
        trash_hover = trash_pos + hover_offset
        obs, reward, done = move(env, task, trash_hover)
        if done:
            print("[Task] Terminated during move-above-trash.")
            return

        # STEP-7  pick the trash itself
        obs, reward, done = pick(
            env, task,
            target_pos=trash_pos,
            approach_distance=0.12,
            approach_axis='-z'
        )
        if done:
            print("[Task] Terminated during pick-trash.")
            return

        # STEP-8  move above the bin
        bin_hover = bin_pos + hover_offset
        obs, reward, done = move(env, task, bin_hover)
        if done:
            print("[Task] Terminated during move-above-bin.")
            return

        # STEP-9  place trash into the bin
        obs, reward, done = place(
            env, task,
            target_pos=bin_pos,
            approach_distance=0.12,
            approach_axis='-z'
        )

        # ------------------------------------------------------------------ #
        #  Outcome                                                           #
        # ------------------------------------------------------------------ #
        if done:
            print("\n[Task] ✔  Task completed successfully! Reward:", reward)
        else:
            print("\n[Task] Plan executed, but environment reports done=False.")

    finally:
        # Always ensure a graceful shutdown (also runs on early return/raise).
        shutdown_environment(env)
        print("\n================  Shutdown & Cleanup  ================\n")


# --------------------------------------------------------------------------- #
#  Entry-point guard                                                          #
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    # Only when executed directly as a script (not on import): run the plan.
    run_task_open_drawer_and_dispose()