import numpy as np
from scipy.spatial.transform import Rotation as R

# === RLBench / project specific helpers ===
from env import setup_environment, shutdown_environment
from skill_code import rotate, move, pick, pull, place
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


# ---------------------------------------------------------------------------
#  Small local helpers
# ---------------------------------------------------------------------------
def _safe_pos(positions, primary_key, fallback_key=None):
    """
    Return a 3-D numpy array for the requested key, falling back if needed.

    Args
    ----
    positions : dict
        Mapping from object-name strings to (x, y, z) tuples.
    primary_key : str
        Preferred key to look up.
    fallback_key : str or None
        Fallback key if the preferred one is missing.

    Raises
    ------
    KeyError if neither key is found.
    """
    if primary_key in positions:
        return np.asarray(positions[primary_key], dtype=float)
    if fallback_key and fallback_key in positions:
        return np.asarray(positions[fallback_key], dtype=float)
    raise KeyError(
        f"Missing both '{primary_key}'"
        + (f" and fallback '{fallback_key}'" if fallback_key else "")
        + " in object positions."
    )


def _infer_pull_axis_and_distance(anchor, joint):
    """
    Given two 3-D positions (the drawer handle anchor and the drawer’s joint or
    end-position after opening), return (axis_string, positive_distance)
    representing the dominant axis along which the drawer should be pulled.
    """
    diff = joint - anchor
    axis_id = int(np.argmax(np.abs(diff)))        # 0:x, 1:y, 2:z
    axis_letter = ['x', 'y', 'z'][axis_id]
    distance = diff[axis_id]
    if distance < 0:
        axis_letter = f'-{axis_letter}'
        distance = -distance
    return axis_letter, float(distance)


# ---------------------------------------------------------------------------
#  Main task logic
# ---------------------------------------------------------------------------
def run_skeleton_task():
    """
    Run the scripted drawer-opening and tomato-placing plan once.

    Sequence, as executed below:
      1. Set up the environment, reset the task and install the video
         recording wrappers around ``task.step`` / ``task.get_observation``.
      2. Read the drawer, plate and tomato positions.
      3. Execute the skills in order: rotate, move (side), move (anchor),
         pick (handle), pull, then pick/place for each of the two tomatoes.
         If any skill other than the last reports ``done``, a message is
         printed and the plan stops early.
      4. Print the final status based on the last skill's ``done`` flag.

    The environment is always shut down, even when a step raises or the plan
    stops early.
    """
    print("===== Starting Skeleton Task =====")
    env, task = setup_environment()

    try:
        # --- reset and hook up video capture ------------------------------
        _, obs = task.reset()
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # --- gather positions (queried only after the reset above) --------
        positions = get_object_positions()

        # The bottom drawer is the one this plan operates on.
        bottom_side_pos = _safe_pos(positions, 'bottom_side_pos')
        bottom_anchor_pos = _safe_pos(positions, 'bottom_anchor_pos')
        bottom_joint_pos = _safe_pos(positions, 'bottom_joint_pos')

        plate_pos = _safe_pos(positions, 'plate')
        tomato1_pos = _safe_pos(positions, 'tomato1', 'item1')
        tomato2_pos = _safe_pos(positions, 'tomato2', 'item2')

        # 90 degrees about z; scipy returns the quaternion as (x, y, z, w).
        ninety_deg_quat = R.from_euler('z', 90, degrees=True).as_quat()

        def _pull_drawer_open():
            # Evaluated lazily so the axis inference runs right before the
            # pull, i.e. only after the handle has been grasped.
            pull_axis, pull_dist = _infer_pull_axis_and_distance(
                bottom_anchor_pos, bottom_joint_pos
            )
            return pull(env, task,
                        pull_distance=pull_dist,
                        pull_axis=pull_axis)

        # Every entry is (deferred skill call, message printed if that call
        # reports done=True). The final place is handled separately because
        # its done flag decides the closing status instead of aborting.
        early_steps = [
            (lambda: rotate(env, task, target_quat=ninety_deg_quat),
             "[Task] Terminated immediately after rotation."),
            (lambda: move(env, task, target_pos=bottom_side_pos),
             "[Task] Terminated after move-to-side."),
            (lambda: move(env, task, target_pos=bottom_anchor_pos),
             "[Task] Terminated after move-to-anchor."),
            (lambda: pick(env, task,
                          target_pos=bottom_anchor_pos,
                          approach_axis='z'),
             "[Task] Terminated during drawer pick."),
            (_pull_drawer_open,
             "[Task] Terminated after pull."),
            (lambda: pick(env, task,
                          target_pos=tomato1_pos,
                          approach_axis='z'),
             "[Task] Aborted during tomato-1 pick."),
            (lambda: place(env, task,
                           target_pos=plate_pos,
                           approach_axis='z'),
             "[Task] Aborted during tomato-1 place."),
            (lambda: pick(env, task,
                          target_pos=tomato2_pos,
                          approach_axis='z'),
             "[Task] Aborted during tomato-2 pick."),
        ]

        for run_step, stop_message in early_steps:
            obs, reward, done = run_step()
            if done:
                print(stop_message)
                return

        # Last action: put the second tomato on the plate.
        obs, reward, done = place(env, task,
                                  target_pos=plate_pos,
                                  approach_axis='z')

        # --- final status --------------------------------------------------
        if not done:
            print("[Task] Plan executed; environment did not signal completion (done=False).")
        else:
            print(f"[Task] Goal achieved! Final reward: {reward}")

    finally:
        shutdown_environment(env)
        print("===== End of Skeleton Task =====")


# Run the task only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run_skeleton_task()