import numpy as np
from scipy.spatial.transform import Rotation as R

# RLBench / environment helpers
from env import setup_environment, shutdown_environment
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions

# ---- Pre-implemented low-level skills (DO NOT re-implement) -----------------
from skill_code import rotate, move, pick, pull, place
# -----------------------------------------------------------------------------


# -----------------------------------------------------------------------------
#  HELPER: decide which available object we will treat as “rubbish”
# -----------------------------------------------------------------------------
def _select_rubbish(positions_dict):
    """
    Simple heuristic:
      • If a key literally called 'rubbish' exists, use that.
      • Otherwise, pick the first key that starts with “item”.
      • If neither exist, return None (caller must handle).
    """
    if "rubbish" in positions_dict:
        return "rubbish"
    for name in positions_dict.keys():
        if name.lower().startswith("item"):
            return name
    return None


# -----------------------------------------------------------------------------
#  MAIN ORACLE CONTROLLER – executes the 7-step plan from the specification
# -----------------------------------------------------------------------------
def run_combined_task():
    """
    Run the composite drawer-and-disposal episode from start to finish.

    The skills are invoked in this fixed order:

      1) rotate   to the quaternion for a 90° rotation about Z
      2) move     to positions["bottom_side_pos"]
      3) move     to positions["bottom_anchor_pos"]
      4) pick     at the anchor position (approach_axis="y")
      5) pull     with pull_distance=0.20 along pull_axis="x"
      6) pick     the rubbish object (approach_axis="z")
      7) place    at positions["bin"] (approach_axis="z")

    Steps 1-6 end the run early (after printing a notice) when the skill
    reports ``done``; after step 7 a summary line with the reward is
    printed instead.

    Raises:
        KeyError: when a mandatory position key is missing, or when no
            rubbish object can be selected.  Any exception raised inside
            the episode is printed and then re-raised.

    The environment is always shut down before the function returns or
    propagates an exception.
    """
    print("\n================  START  COMBINED  TASK  ================\n")

    # -------------------------------------------------------------------------
    # Environment start-up (kept outside the try so that `env` is always bound
    # when the finally-clause runs).
    # -------------------------------------------------------------------------
    env, task = setup_environment()
    try:
        _descriptions, first_obs = task.reset()

        # Route step / observation calls through the recording wrappers.
        init_video_writers(first_obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # ---------------------------------------------------------------------
        # Look up the object positions and validate the mandatory entries.
        # ---------------------------------------------------------------------
        positions = get_object_positions()

        required = ("bottom_side_pos", "bottom_anchor_pos", "bin")
        absent = [key for key in required if key not in positions]
        if absent:
            raise KeyError(f"[Oracle] Missing required position keys: {absent}")

        side_pos, anchor_pos, bin_pos = (
            np.asarray(positions[key], dtype=np.float32) for key in required
        )

        rubbish_key = _select_rubbish(positions)
        if rubbish_key is None:
            raise KeyError("[Oracle] Unable to determine the rubbish object.")
        rubbish_pos = np.asarray(positions[rubbish_key], dtype=np.float32)
        print(f"[Oracle] Treating ‘{rubbish_key}’ as rubbish.")

        # ---------------------------------------------------------------------
        # Describe the plan as data: (banner, skill, keyword args, early-exit
        # notice).  A notice of None marks the final step, whose `done` flag is
        # reported by the summary below rather than triggering an early return.
        # ---------------------------------------------------------------------
        target_quat = R.from_euler("xyz", [0, 0, 90], degrees=True).as_quat()
        shared = {"max_steps": 120, "timeout": 10.0}

        plan = [
            (
                "\n----  STEP 1 / 7  : ROTATE  ----",
                rotate,
                {"target_quat": target_quat, "threshold": 0.05},
                "[Oracle] Episode finished early during ROTATE.",
            ),
            (
                "\n----  STEP 2 / 7  : MOVE to bottom_side_pos  ----",
                move,
                {"target_pos": side_pos, "threshold": 0.01},
                "[Oracle] Episode finished early during MOVE-to-side.",
            ),
            (
                "\n----  STEP 3 / 7  : MOVE to bottom_anchor_pos  ----",
                move,
                {"target_pos": anchor_pos, "threshold": 0.01},
                "[Oracle] Episode finished early during MOVE-to-anchor.",
            ),
            (
                "\n----  STEP 4 / 7  : PICK drawer handle  ----",
                pick,
                {
                    "target_pos": anchor_pos,
                    "approach_distance": 0.08,
                    "threshold": 0.01,
                    "approach_axis": "y",
                },
                "[Oracle] Episode finished early during PICK-drawer.",
            ),
            (
                "\n----  STEP 5 / 7  : PULL drawer  ----",
                pull,
                {"pull_distance": 0.20, "pull_axis": "x", "threshold": 0.01},
                "[Oracle] Episode finished early during PULL.",
            ),
            (
                "\n----  STEP 6 / 7  : PICK rubbish  ----",
                pick,
                {
                    "target_pos": rubbish_pos,
                    "approach_distance": 0.15,
                    "threshold": 0.01,
                    "approach_axis": "z",
                },
                "[Oracle] Episode finished early during PICK-rubbish.",
            ),
            (
                "\n----  STEP 7 / 7  : PLACE rubbish into bin  ----",
                place,
                {
                    "target_pos": bin_pos,
                    "approach_distance": 0.15,
                    "threshold": 0.01,
                    "approach_axis": "z",
                },
                None,
            ),
        ]

        # ---------------------------------------------------------------------
        # Execute the plan in order.
        # ---------------------------------------------------------------------
        for banner, skill, skill_kwargs, early_notice in plan:
            print(banner)
            _, reward, done = skill(env, task, **skill_kwargs, **shared)
            if done and early_notice is not None:
                print(early_notice)
                return

        # ---------------------------------------------------------------------
        # Outcome summary (uses reward / done from the final PLACE step).
        # ---------------------------------------------------------------------
        if not done:
            print("\n[Oracle] Task did not signal done (may be expected). Reward:", reward)
        else:
            print("\n[Oracle] Task reported DONE after PLACE. Reward:", reward)

    except Exception as exc:
        # Surface the failure in the log, then let it propagate unchanged.
        print(f"[Oracle] Exception: {exc}")
        raise
    finally:
        # Runs on success, early return and failure alike.
        shutdown_environment(env)
        print("\n================   END  COMBINED  TASK   ================\n")


# -----------------------------------------------------------------------------  
#  Script entry-point
# -----------------------------------------------------------------------------
# Run the task only when this file is executed directly as a script;
# importing the module has no side effects beyond the definitions above.
if __name__ == "__main__":
    run_combined_task()