import numpy as np
from scipy.spatial.transform import Rotation as R

from env import setup_environment, shutdown_environment
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions

# === Import ONLY the predefined skills exactly as provided ===
from skill_code import move, pick, place, rotate, pull


def _safe_get_position(name: str, positions_dict: dict) -> np.ndarray:
    """
    Resolve a 3-D world-frame position for a named Shape.

    Priority:
      1) cached dictionary returned by get_object_positions()
      2) direct query to CoppeliaSim (via pyrep Shape)

    Raises
    ------
    RuntimeError
        If the position cannot be resolved.
    """
    if name in positions_dict:
        return np.asarray(positions_dict[name], dtype=np.float32)

    # Fallback – ask the simulator directly
    try:
        from pyrep.objects.shape import Shape
        return np.asarray(Shape(name).get_position(), dtype=np.float32)
    except Exception as exc:
        raise RuntimeError(f"[run_task] Cannot resolve position for object "
                           f"'{name}'.") from exc


def _open_gripper(env, task, n_steps: int = 5) -> None:
    """
    Utility that keeps the end-effector fixed while opening the gripper.

    Parameters
    ----------
    n_steps : int
        Number of simulation steps to apply the open command.
    """
    obs = task.get_observation()
    fixed_pos  = obs.gripper_pose[:3]
    fixed_quat = obs.gripper_pose[3:7]

    action = np.zeros(env.action_shape, dtype=np.float32)
    action[:3]   = fixed_pos
    action[3:7]  = fixed_quat
    action[-1]   = 1.0                       #  +1 → open

    for _ in range(n_steps):
        obs, _, done = task.step(action)
        if done:
            break


def run_skeleton_task() -> None:
    """
    Execute the oracle plan that (1) opens a drawer and (2) disposes of the
    rubbish from the table into the bin.

    The exact sequence is governed by the Specification:
        1) rotate
        2) move-to-side
        3) move-to-anchor
        4) pick-drawer
        5) pull
        6) pick  (rubbish)
        7) place (into bin)

    Between steps 5 and 6 the drawer handle is released and the gripper is
    lifted 10 cm straight up from wherever it currently is.

    Each skill's ``done`` flag is checked. If the episode ends early, the
    remaining steps are skipped. Once ``setup_environment`` has succeeded,
    ``shutdown_environment`` is always called, including on early return
    or on an exception.

    Raises
    ------
    RuntimeError
        Propagated from ``_safe_get_position`` if a required object
        position cannot be resolved (the environment is still shut down).
    """
    print("\n==========  RUNNING  COMBINED  TASK  ==========")

    # ----------------------------------------------------------------------
    # 1)  Environment initialisation & optional video recording set-up
    # ----------------------------------------------------------------------
    # Created outside the try: if setup itself fails there is no env to
    # shut down.
    env, task = setup_environment()
    try:
        _, obs = task.reset()   # task descriptions are not used here

        # Optional video recording: replace task.step / task.get_observation
        # with the wrapped versions from `video`. All later calls made
        # through `task` use the wrappers.
        init_video_writers(obs)
        task.step            = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # ------------------------------------------------------------------
        # 2)  Gather static object poses
        # ------------------------------------------------------------------
        positions = get_object_positions()  # May be empty if not pre-generated

        # Drawer helper points (we use the *bottom* drawer for this example)
        side_pos_bottom   = _safe_get_position("bottom_side_pos",   positions)
        anchor_pos_bottom = _safe_get_position("bottom_anchor_pos", positions)

        # Rubbish on the table and the bin location
        rubbish_pos = _safe_get_position("rubbish", positions)
        bin_pos     = _safe_get_position("bin",     positions)

        # ------------------------------------------------------------------
        # 3)  Execute oracle plan (see specification)
        # ------------------------------------------------------------------

        # ---- Step 1 : rotate(gripper, zero_deg → ninety_deg) ---------------
        # scipy's as_quat() is scalar-last: (x, y, z, w).
        # NOTE(review): this quaternion is a 90° rotation about Z from the
        # identity orientation, independent of the gripper's current
        # orientation. Confirm that rotate() expects an absolute target.
        quat_90z = R.from_euler('z', 90, degrees=True).as_quat()
        print("\n[Plan-Step 1] Rotating gripper to 90° about Z")
        obs, reward, done = rotate(env, task, target_quat=quat_90z)
        if done:
            print("[run_task] Episode finished during rotate.")
            return

        # ---- Step 2 : move-to-side(gripper → side-pos) ----------------------
        print("\n[Plan-Step 2] Moving to drawer side-handle point:",
              side_pos_bottom)
        obs, reward, done = move(env, task, target_pos=side_pos_bottom)
        if done:
            print("[run_task] Episode finished during move-to-side.")
            return

        # ---- Step 3 : move-to-anchor(gripper → anchor-pos) ------------------
        print("\n[Plan-Step 3] Moving to drawer anchor point:",
              anchor_pos_bottom)
        obs, reward, done = move(env, task, target_pos=anchor_pos_bottom)
        if done:
            print("[run_task] Episode finished during move-to-anchor.")
            return

        # ---- Step 4 : pick-drawer(gripper, bottom, anchor-pos) -------------
        print("\n[Plan-Step 4] Grasping drawer handle (pick-drawer).")
        obs, reward, done = pick(
            env,
            task,
            target_pos=anchor_pos_bottom,
            approach_distance=0.10,
            approach_axis='z'
        )
        if done:
            print("[run_task] Episode finished during pick-drawer.")
            return

        # ---- Step 5 : pull(gripper, bottom) --------------------------------
        print("\n[Plan-Step 5] Pulling drawer outwards.")
        obs, reward, done = pull(
            env,
            task,
            pull_distance=0.20,    # ~20 cm
            pull_axis='x'
        )
        if done:
            print("[run_task] Episode finished during pull.")
            return

        # ---- Release the drawer handle so we can grasp the rubbish ---------
        print("[Plan] Releasing drawer handle.")
        # A truthy result from _open_gripper is treated as "episode
        # finished". A None/False result means no termination was reported,
        # and the plan continues.
        if _open_gripper(env, task, n_steps=5):
            print("[run_task] Episode finished during gripper release.")
            return

        # Retreat straight up from where the gripper is *now*. pull() was
        # asked to displace the gripper ~0.20 along x, so it is no longer at
        # anchor_pos_bottom. Targeting the pre-pull anchor would drag the
        # open gripper back across the drawer instead of lifting it clear.
        current_pos = np.asarray(task.get_observation().gripper_pose[:3],
                                 dtype=np.float32)
        retreat_pos = current_pos + np.array([0.0, 0.0, 0.10],
                                             dtype=np.float32)
        obs, reward, done = move(env, task, target_pos=retreat_pos)
        if done:
            print("[run_task] Episode finished during retreat.")
            return

        # ---- Step 6 : pick(rubbish) ----------------------------------------
        print("\n[Plan-Step 6] Picking up rubbish at:", rubbish_pos)
        obs, reward, done = pick(
            env,
            task,
            target_pos=rubbish_pos,
            approach_distance=0.15,
            approach_axis='z'
        )
        if done:
            print("[run_task] Episode finished during pick-rubbish.")
            return

        # ---- Step 7 : place(rubbish → bin) ---------------------------------
        print("\n[Plan-Step 7] Placing rubbish in bin at:", bin_pos)
        obs, reward, done = place(
            env,
            task,
            target_pos=bin_pos,
            approach_distance=0.15,
            approach_axis='z'
        )
        if done:
            print("[run_task] Episode finished during place-rubbish.")
        else:
            print("[run_task] Plan executed; environment did not set done=True.")

    finally:
        # ------------------------------------------------------------------
        # 4)  Always shut down the environment cleanly
        # ------------------------------------------------------------------
        shutdown_environment(env)
        print("==========  TASK  SHUTDOWN  COMPLETE  ==========")


if __name__ == "__main__":
    # Script entry point: run the combined drawer + rubbish plan once.
    run_skeleton_task()