import numpy as np
from scipy.spatial.transform import Rotation as R

from env import setup_environment, shutdown_environment
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions

# ==== Skills (already provided in skill_code) ====
from skill_code import rotate, move, pick, pull, place


# ----------------------------------------------------------------------
#  Utility helpers (local only, do NOT redefine skills)
# ----------------------------------------------------------------------
def _make_quat_from_euler(rx=0.0, ry=0.0, rz=0.0, seq='xyz'):
    """Build an (x, y, z, w) quaternion from Euler angles expressed in degrees.

    Args:
        rx, ry, rz: rotation angles in degrees, applied in the order and
            frame convention given by ``seq``.
        seq: axis sequence understood by ``scipy.spatial.transform.Rotation``
            (e.g. ``'xyz'`` extrinsic, ``'XYZ'`` intrinsic).

    Returns:
        numpy array of shape (4,), scalar-last (x, y, z, w).
    """
    angles_in_degrees = [rx, ry, rz]
    rotation = R.from_euler(seq, angles_in_degrees, degrees=True)
    return rotation.as_quat()


def _safe_lookup(dic, *candidates):
    """Return the value stored under the first of ``candidates`` present in ``dic``.

    Keys are tried in the order given, which lets callers list alternative
    names for the same entry (preferred name first).

    Args:
        dic: mapping to search.
        *candidates: keys to try, in priority order.

    Returns:
        ``dic[k]`` for the first ``k`` in ``candidates`` with ``k in dic``
        (the stored value may itself be falsy or ``None``).

    Raises:
        KeyError: if none of ``candidates`` is a key of ``dic``
            (including when no candidates are given).
    """
    for key in candidates:
        # Membership test rather than dic.get() so that stored None/falsy
        # values are still returned instead of being skipped.
        if key in dic:
            return dic[key]
    raise KeyError(f'None of the keys {candidates} found!')


# ======================================================================
#                     MAIN ORACLE PLAN EXECUTOR
# ======================================================================
def _abort_if_done(done, where):
    """Print an abort message if a skill reported ``done``; return ``done``.

    ``where`` completes the sentence "Task terminated <where>." (e.g.
    ``"after rotate"``), so one helper reproduces every per-step abort message.
    """
    if done:
        print(f"[Abort] Task terminated {where}.")
    return done


def run_open_drawer_and_dispose():
    """
    Oracle plan that fulfils: “Pull open any drawer that is not locked,
    then drop the rubbish into the bin.”

    Always operates on the bottom drawer. Skills are called in this order,
    and the plan returns early as soon as any call before the final ``place``
    reports ``done``:
        1) rotate  (gripper to 90° about z)
        2) move    (→ side-position of bottom drawer)
        3) move    (→ anchor-position, the drawer handle)
        4) pick    (closes gripper on the handle, i.e. pick-drawer)
        5) pull    (open the drawer; axis/distance from joint ↔ handle vector)
        6) move    (hover 10 cm above the rubbish)
        7) pick    (grab the rubbish)
        8) move    (hover 10 cm above the bin)
        9) place   (drop rubbish into the bin)

    The environment is shut down on every exit path: normal completion,
    early abort, or exception. Exceptions are printed and then re-raised.

    Raises:
        KeyError: if a required position is missing from
            ``get_object_positions()``.
    """
    print("\n===== [Task] Start: Open Drawer & Dispose Rubbish =====")

    # ----------------------------------------------------------
    #  Environment / task initialisation
    # ----------------------------------------------------------
    env, task = setup_environment()
    try:
        descriptions, obs = task.reset()

        # ----- Optional video capture: wrap step / get_observation -----
        init_video_writers(obs)
        original_step = task.step
        task.step = recording_step(original_step)
        original_get_obs = task.get_observation
        task.get_observation = recording_get_observation(original_get_obs)

        # ----------------------------------------------------------
        #  Retrieve object positions (helper module)
        # ----------------------------------------------------------
        pos = get_object_positions()

        # ---- Drawer related (bottom drawer) ----
        # Alternative key names are tried in order; the first one present wins.
        side_pos = np.asarray(_safe_lookup(pos,
                                           'bottom_side_pos', 'side_pos_bottom'))
        anchor_pos = np.asarray(_safe_lookup(pos,
                                             'bottom_anchor_pos', 'anchor_pos_bottom'))
        joint_pos = np.asarray(_safe_lookup(pos,
                                            'bottom_joint_pos', 'joint_pos_bottom'))

        # ---- Rubbish & Bin ----
        rubbish_pos = np.asarray(_safe_lookup(pos, 'rubbish', 'item3'))
        bin_pos = np.asarray(_safe_lookup(pos, 'bin'))

        # Offset used to hover above a target before picking/placing.
        hover_offset = np.array([0.0, 0.0, 0.10])

        # ----------------------------------------------------------
        #                ====  Execute plan  ====
        # ----------------------------------------------------------

        # 1) ROTATE gripper from zero to ninety degrees (around Z)
        target_quat = _make_quat_from_euler(rz=90)
        obs, reward, done = rotate(env, task, target_quat)
        if _abort_if_done(done, "after rotate"):
            return

        # 2) MOVE to drawer side-position
        obs, reward, done = move(env, task, side_pos)
        if _abort_if_done(done, "after move-to-side"):
            return

        # 3) MOVE to drawer anchor-position (handle)
        obs, reward, done = move(env, task, anchor_pos)
        if _abort_if_done(done, "after move-to-anchor"):
            return

        # 4) PICK (close gripper) on the handle → acts as pick-drawer.
        # NOTE(review): this uses approach_axis='z' while the rubbish pick
        # and the bin place use '-z'; confirm against the pick skill's
        # axis convention which one means "approach from above".
        obs, reward, done = pick(env, task,
                                 target_pos=anchor_pos,
                                 approach_distance=0.08,
                                 approach_axis='z')
        if _abort_if_done(done, "after pick-drawer"):
            return

        # 5) PULL drawer open.
        #    Axis = dominant component of the handle → joint vector; sign
        #    taken from that component. Distance = its magnitude + 5 cm margin.
        # NOTE(review): this pulls *toward* bottom_joint_pos; confirm that is
        # the opening direction for this scene's joint placement.
        diff_vec = joint_pos - anchor_pos
        axis_idx = int(np.argmax(np.abs(diff_vec)))          # 0 = x, 1 = y, 2 = z
        axis_char = 'xyz'[axis_idx]
        if diff_vec[axis_idx] < 0:
            axis_char = f'-{axis_char}'
        pull_distance = np.abs(diff_vec[axis_idx]) + 0.05     # small safety margin

        obs, reward, done = pull(env, task,
                                 pull_distance=pull_distance,
                                 pull_axis=axis_char)
        if _abort_if_done(done, "while pulling drawer"):
            return

        # 6) MOVE above the rubbish (hover 10 cm).
        # NOTE(review): nothing here explicitly releases the drawer handle
        # before this move; presumably the next pick opens the gripper
        # first — confirm.
        obs, reward, done = move(env, task, rubbish_pos + hover_offset)
        if _abort_if_done(done, "while approaching rubbish"):
            return

        # 7) PICK the rubbish
        obs, reward, done = pick(env, task,
                                 target_pos=rubbish_pos,
                                 approach_distance=0.08,
                                 approach_axis='-z')
        if _abort_if_done(done, "after picking rubbish"):
            return

        # 8) MOVE above the bin (hover 10 cm)
        obs, reward, done = move(env, task, bin_pos + hover_offset)
        if _abort_if_done(done, "while moving above bin"):
            return

        # 9) PLACE rubbish into the bin
        obs, reward, done = place(env, task,
                                  target_pos=bin_pos,
                                  approach_distance=0.08,
                                  approach_axis='-z')
        if done:
            print("[Task] Completed successfully! Reward:", reward)
        else:
            print("[Task] Finished sequence (done == False).")

    except Exception as exc:
        print("[Exception] An error occurred during execution:", exc)
        raise

    finally:
        # Runs on normal completion, on every early `return`, and on errors.
        shutdown_environment(env)
        print("===== [Task] Environment Shutdown Complete =====")


# ----------------------------------------------------------------------
#  Entry-Point
# ----------------------------------------------------------------------
if __name__ == "__main__":
    # Run the oracle plan only when this file is executed as a script,
    # not when it is imported.
    run_open_drawer_and_dispose()