import numpy as np
from scipy.spatial.transform import Rotation as R

from env import setup_environment, shutdown_environment
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions

# Import every predefined skill exactly as delivered in ``skill_code``
from skill_code import *   # brings in move, pick, place, rotate, pull, …

# ------------------------------------------------------------------
# --------------------  Helper / Utility Functions  ----------------
# ------------------------------------------------------------------
def _quat_from_z_rotation(deg: float = 90.0):
    """Return a quaternion (x,y,z,w) that corresponds to a rotation of `deg`
    degrees around the Z-axis (using scipy’s XYZ-convention)."""
    return R.from_euler('z', deg, degrees=True).as_quat()


def _guess_pull_axis(anchor_pos: np.ndarray,
                     joint_pos: np.ndarray,
                     safety_margin: float = 0.05):
    """
    Infer a 1-D axis (“x”, “-x”, “y”, …) and (positive) distance that will pull
    the drawer handle (anchor_pos) directly away from the drawer’s rotational
    joint (joint_pos).

    Returns
    -------
    axis_str : str   One of {'x','-x','y','-y','z','-z'}
    distance : float Distance along that axis
    """
    vec = anchor_pos - joint_pos                    # direction away from pivot
    dominant_idx = int(np.argmax(np.abs(vec)))      # axis with largest comp.
    axis_letter = ('x', 'y', 'z')[dominant_idx]
    sign = '' if vec[dominant_idx] >= 0 else '-'
    axis_str = f'{sign}{axis_letter}'
    # Pull a little further than the pure geometric offset
    distance   = float(np.abs(vec[dominant_idx]) + safety_margin)
    return axis_str, distance


def _get_first_available(positions: dict, *candidates):
    """Return the first positional entry that exists in *positions*."""
    for name in candidates:
        pos = positions.get(name)
        if pos is not None:
            return pos
    return None


# ------------------------------------------------------------------
# ----------------------  Main Oracle-Plan Runner  -----------------
# ------------------------------------------------------------------
def run_skeleton_task():
    """
    Execute the oracle plan described in the specification:

        1) rotate gripper  ->  +90° about Z
        2) move to drawer side position
        3) move to drawer anchor (handle)
        4) pick the handle
        5) pull drawer open
        6) pick tomato-1 and place it on the plate   (logged as 6a / 6b)
        7) pick tomato-2 and place it on the plate   (logged as 7a / 7b)

    Each skill call returns ``(obs, reward, done)``. Once ``done`` is True,
    the remaining steps are skipped. The environment is shut down in a
    ``finally`` clause, even if a step raises.

    Raises
    ------
    RuntimeError
        If any position the plan needs is missing from
        ``get_object_positions()``.
    """
    print("===== Starting Skeleton Task =====")

    # --------------- 1) Environment initialisation -----------------
    env, task = setup_environment()
    try:
        # Reset and obtain the very first observation
        _, obs = task.reset()

        # Optional: enable video recording
        init_video_writers(obs)

        original_step            = task.step
        task.step                = recording_step(original_step)
        original_get_observation = task.get_observation
        task.get_observation     = recording_get_observation(original_get_obs=original_get_observation)

        # ----------- 2) Gather all object positions we may need -----
        positions = get_object_positions()

        # Drawer related positions (bottom drawer chosen by oracle plan)
        bottom_side_pos   = _get_first_available(positions, 'bottom_side_pos')
        bottom_anchor_pos = _get_first_available(positions, 'bottom_anchor_pos')
        bottom_joint_pos  = _get_first_available(positions, 'bottom_joint_pos')

        # Plate and tomatoes ('itemN' is accepted as a fallback name)
        plate_pos   = _get_first_available(positions, 'plate')
        tomato1_pos = _get_first_available(positions, 'tomato1', 'item1')
        tomato2_pos = _get_first_available(positions, 'tomato2', 'item2')

        # Sanity check – abort early if anything critical is missing
        required = {
            'bottom_side_pos'  : bottom_side_pos,
            'bottom_anchor_pos': bottom_anchor_pos,
            'bottom_joint_pos' : bottom_joint_pos,
            'plate_pos'        : plate_pos,
            'tomato1_pos'      : tomato1_pos,
            'tomato2_pos'      : tomato2_pos
        }
        missing = [k for k, v in required.items() if v is None]
        if missing:
            raise RuntimeError(f"[ERROR] Missing positions for: {', '.join(missing)}")

        # ------------------------------------------------------------------
        # --------------------- 3) Execute oracle plan ----------------------
        # ------------------------------------------------------------------
        done = False
        reward = None   # defined up front so the final report never sees an unbound name

        # STEP-1  rotate(gripper, zero_deg → ninety_deg) – always runs first
        print("\n[Step-1] Rotate gripper to +90° about Z")
        quat_90 = _quat_from_z_rotation(90.0)
        obs, reward, done = rotate(env, task, target_quat=quat_90)

        # STEP-2  move-to-side
        if not done:
            print("\n[Step-2] Move to bottom_side_pos")
            obs, reward, done = move(env, task, target_pos=bottom_side_pos)

        # STEP-3  move-to-anchor
        if not done:
            print("\n[Step-3] Move to bottom_anchor_pos (handle)")
            obs, reward, done = move(env, task, target_pos=bottom_anchor_pos)

        # STEP-4  pick-drawer (grasp handle)
        if not done:
            print("\n[Step-4] Grasp drawer handle")
            obs, reward, done = pick(env, task,
                                     target_pos=bottom_anchor_pos,
                                     approach_distance=0.10,
                                     approach_axis='z')

        # STEP-5  pull drawer open along the dominant anchor–joint axis
        if not done:
            print("\n[Step-5] Pull drawer open")
            axis, dist = _guess_pull_axis(bottom_anchor_pos, bottom_joint_pos)
            obs, reward, done = pull(env, task,
                                     pull_distance=dist,
                                     pull_axis=axis)

        # -------------------- 4) Handle the tomatoes -----------------------
        tomatoes = [('tomato1', tomato1_pos),
                    ('tomato2', tomato2_pos)]

        # One plan step per tomato (6, 7); pick/place are sub-steps a/b so
        # that labels never collide across tomatoes.
        for step_no, (name, pos) in enumerate(tomatoes, start=6):
            if done:
                break

            # PICK tomato
            print(f"\n[Step-{step_no}a] Pick {name}")
            obs, reward, done = pick(env, task,
                                     target_pos=pos,
                                     approach_distance=0.12,
                                     approach_axis='-z')
            if done:
                break

            # PLACE onto plate
            print(f"\n[Step-{step_no}b] Place {name} on plate")
            obs, reward, done = place(env, task,
                                      target_pos=plate_pos,
                                      approach_distance=0.12,
                                      approach_axis='-z')

        # ------------------------------------------------------------------
        # ------------------------ 5) Final report --------------------------
        # ------------------------------------------------------------------
        if done:
            print(f"[Task] RLBench signalled completion (done=True). Final reward: {reward}")
        else:
            print("[Task] Oracle plan executed fully – check goal condition in simulator.")

    finally:
        # Always shut down the environment
        shutdown_environment(env)
        print("===== End of Skeleton Task =====")


# Entry-point guard: run the oracle plan only when this file is executed
# directly as a script, not when it is imported as a module.
if __name__ == "__main__":
    run_skeleton_task()