import numpy as np

from env import setup_environment, shutdown_environment
from skill_code import rotate, move, pick, pull, place
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


# ---------------------------------------------------------------------
# Helper utilities
# ---------------------------------------------------------------------
def _safe_get_position(obj_dict, name):
    """
    Extract a pure XYZ position (np.ndarray, shape=(3,), dtype=float32) for
    `name` from the dictionary returned by `get_object_positions()`.

    The helper is robust to several possible entry layouts:
       1) {'position': [...]}
       2) {'pos': [...]}
       3) {'xyz': [...]}
       4) the entry itself already being an (x, y, z) list / tuple / ndarray

    For the mapping layouts the keys are tried in the order listed above; a
    key whose value is not usable as XYZ is skipped and the next one is tried.

    Parameters
    ----------
    obj_dict : mapping of object name -> entry (any layout listed above)
    name     : key to look up in `obj_dict`

    Returns
    -------
    np.ndarray
        Freshly converted float32 vector of shape (3,).

    Raises
    ------
    ValueError
        If the entry exists but none of the layouts yields exactly three
        numeric values.

    A missing `name` is not handled here: the exception raised by
    `obj_dict[name]` (KeyError for a plain dict) propagates unchanged.
    """
    # Local import keeps this helper self-contained (no new file-level import).
    from collections.abc import Mapping

    def _as_xyz(candidate):
        # Convert to float32 and accept only values holding exactly three
        # numbers; anything else (None, scalars, wrong length, non-numeric,
        # ragged or (3, k) data) is reported as "not a position" via None.
        try:
            vec = np.asarray(candidate, dtype=np.float32)
        except (TypeError, ValueError):
            return None
        if vec.size != 3:
            return None
        # Flatten so that (3, 1) / (1, 3) inputs also honour the (3,) contract.
        return vec.reshape(3)

    entry = obj_dict[name]

    # Case 4 – the entry itself is the coordinate triple
    if isinstance(entry, (list, tuple, np.ndarray)):
        xyz = _as_xyz(entry)
        if xyz is not None:
            return xyz

    # Case 1 / 2 / 3 – dictionary-style layouts. Only real mappings are probed
    # so that None / scalars / arrays never hit `key in entry` and fail with an
    # unrelated TypeError.
    elif isinstance(entry, Mapping):
        for key in ("position", "pos", "xyz"):
            if key in entry:
                xyz = _as_xyz(entry[key])
                if xyz is not None:
                    return xyz

    # If nothing matched, raise a descriptive error
    raise ValueError(f"[ObjectPositions] Cannot extract XYZ position for key '{name}'. "
                     f"Entry={entry}")


def _angle_name_to_quat(name):
    """
    Translate a discrete angle label into the quaternion (xyzw order,
    float32) that is handed to the `rotate` skill.

    Recognised labels (case-insensitive, surrounding whitespace ignored):
      • 'zero_deg', 'zero', '0', '0deg'  → identity  [0, 0, 0, 1]
      • 'ninety_deg', '90', '90deg'      → 90° about the Z-axis
                                            [0, 0, sin(45°), cos(45°)]

    Any other label raises ValueError. To support further discrete angles,
    add their aliases to the lookup table below.
    """
    no_rotation = (0.0, 0.0, 0.0, 1.0)
    # Same value as R.from_euler('z', 90, degrees=True).as_quat()
    quarter_turn_z = (0.0, 0.0, 0.70710678, 0.70710678)

    quat_by_alias = {
        "zero_deg": no_rotation,
        "zero": no_rotation,
        "0": no_rotation,
        "0deg": no_rotation,
        "ninety_deg": quarter_turn_z,
        "90": quarter_turn_z,
        "90deg": quarter_turn_z,
    }

    # Normalise first; the error message below reports the normalised label.
    name = name.lower().strip()
    components = quat_by_alias.get(name)
    if components is None:
        raise ValueError(f"[AngleMapping] Unsupported discrete angle name '{name}'.")

    # A new array per call, so callers may mutate the result freely.
    return np.asarray(components, dtype=np.float32)


# ---------------------------------------------------------------------
# Oracle‑plan execution
# ---------------------------------------------------------------------
def run_combined_task():
    """
    Run the fixed seven-step oracle plan against a freshly set-up environment.

    Sequence: rotate the gripper to 'ninety_deg', move to the bottom drawer's
    side position, move to its anchor position, pick at the anchor, pull,
    pick the rubbish, and finally place at the bin position.

    After each of the first six steps the plan stops early if the skill
    reports `done`. The environment is shut down in every case (normal
    completion, early exit, or exception).
    """
    print("==========  Combined‑Domain Oracle Plan Start ==========")

    env, task = setup_environment()

    try:
        _, first_obs = task.reset()

        # Route step / observation calls through the video-recording wrappers
        init_video_writers(first_obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # Resolve every target position before the first skill is executed,
        # so a missing or malformed entry fails before any motion happens.
        positions = get_object_positions()
        side_xyz, anchor_xyz, rubbish_xyz, bin_xyz = (
            _safe_get_position(positions, key)
            for key in ('bottom_side_pos', 'bottom_anchor_pos', 'rubbish', 'bin')
        )

        # Steps 1-6 as (banner, early-exit message, deferred skill call).
        # The calls are deferred so each one runs only after its banner has
        # been printed and only if no earlier step ended the task.
        leading_steps = (
            (
                "\n--- Step 1 / 7 : rotate gripper zero_deg → ninety_deg ---",
                "[Early‑Exit] Task finished during rotate.",
                lambda: rotate(env, task, _angle_name_to_quat("ninety_deg")),
            ),
            (
                "\n--- Step 2 / 7 : move gripper nowhere‑pos → bottom_side_pos ---",
                "[Early‑Exit] Task finished during move‑to‑side.",
                lambda: move(env, task, side_xyz),
            ),
            (
                "\n--- Step 3 / 7 : move gripper side‑pos → bottom_anchor_pos ---",
                "[Early‑Exit] Task finished during move‑to‑anchor.",
                lambda: move(env, task, anchor_xyz),
            ),
            (
                "\n--- Step 4 / 7 : pick (grab bottom drawer handle) ---",
                "[Early‑Exit] Task finished during pick‑drawer.",
                # Shorter approach distance for the handle than for the rubbish
                lambda: pick(env, task, anchor_xyz,
                             approach_distance=0.07,
                             approach_axis='-z'),
            ),
            (
                "\n--- Step 5 / 7 : pull drawer open (≈15 cm along +x) ---",
                "[Early‑Exit] Task finished during pull.",
                # NOTE: the pull axis is a guess – adjust if the environment differs
                lambda: pull(env, task, pull_distance=0.15, pull_axis='x'),
            ),
            (
                "\n--- Step 6 / 7 : pick rubbish on table ---",
                "[Early‑Exit] Task finished during pick‑rubbish.",
                lambda: pick(env, task, rubbish_xyz,
                             approach_distance=0.15,
                             approach_axis='z'),
            ),
        )

        for banner, early_exit_note, run_skill in leading_steps:
            print(banner)
            _, _, done = run_skill()
            if done:
                print(early_exit_note)
                return

        # Step 7 is handled separately: its `done` flag selects the final
        # status message rather than triggering an early exit.
        print("\n--- Step 7 / 7 : place rubbish into bin ---")
        _, reward, done = place(env, task, bin_xyz,
                                approach_distance=0.20,
                                approach_axis='z')
        if not done:
            print("[Info] Plan finished, but environment reports done=False.")
        else:
            print("[Success] Plan executed completely. Reward:", reward)

    finally:
        shutdown_environment(env)
        print("==========  Combined‑Domain Oracle Plan End ==========")


# Run the oracle plan only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run_combined_task()
