import numpy as np
from pyrep.objects.shape import Shape                                   # pylint: disable=unused-import
from pyrep.objects.proximity_sensor import ProximitySensor              # pylint: disable=unused-import

from env import setup_environment, shutdown_environment

# keep the wildcard-import line exactly as in the skeleton
from skill_code import *                                                # noqa: F403, F401

from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


# ---------------------------------------------------------------------------
# Helper utilities
# ---------------------------------------------------------------------------
def _safe_get_pos(name: str, positions: dict):
    """
    Robustly retrieve the 3-D position of an object.  Accepts a few common
    aliasing conventions (underscores / dashes, upper / lower-case).

    Parameters
    ----------
    name : str
        Canonical object key, e.g. 'bottom_side_pos'.
    positions : dict
        Dictionary returned by get_object_positions().

    Returns
    -------
    np.ndarray
        (x, y, z) position.

    Raises
    ------
    KeyError
        If the object key cannot be resolved.
    """
    if name in positions:
        return np.asarray(positions[name], dtype=np.float32)

    alt_keys = {name.replace('_', '-'),
                name.replace('-', '_'),
                name.lower(),
                name.upper()}
    for k in alt_keys:
        if k in positions:
            return np.asarray(positions[k], dtype=np.float32)

    raise KeyError(f"[Task] '{name}' not found in object_positions.")


def _discover_tomatoes(positions: dict, num_expected: int = 2):
    """
    Return a list with (up to) `num_expected` tomato-like objects.  Falls back
    to generic 'item#' naming if explicit 'tomato' objects do not exist.
    """
    tomato_keys = [k for k in positions if 'tomato' in k.lower()]

    if len(tomato_keys) < num_expected:
        tomato_keys.extend([k for k in positions if k.startswith('item')])
        tomato_keys = tomato_keys[:num_expected]

    if len(tomato_keys) < num_expected:
        raise RuntimeError("[Task] Not enough tomato / item objects in scene.")

    return [np.asarray(positions[k], dtype=np.float32) for k in tomato_keys]


# ---------------------------------------------------------------------------
# Main execution logic that follows the 9-step specification
# ---------------------------------------------------------------------------
def run_skeleton_task():
    """
    Execute the oracle plan: open the bottom drawer, then move two tomatoes
    onto the plate.

    The plan is nine sequential skill calls (rotate, move, move, pick, pull,
    pick, place, pick, place).  If any skill returns ``done`` before the last
    step, the plan stops early.  An exception raised during the plan is
    reported together with its traceback and is not re-raised.  The
    environment is always shut down in the ``finally`` clause.
    """
    import traceback  # local import keeps the skeleton's import header untouched

    print("===== Starting Skeleton Task =====")

    # ------------------------------------------------------------------
    #  Environment initialisation
    # ------------------------------------------------------------------
    env, task = setup_environment()
    try:
        _, obs = task.reset()
        init_video_writers(obs)

        # Wrap step / get_observation so that we automatically record video
        original_step = task.step
        task.step = recording_step(original_step)
        original_get_obs = task.get_observation
        task.get_observation = recording_get_observation(original_get_obs)

        # Fetch fresh positions whenever objects may have moved
        refresh_positions = get_object_positions

        # ------------------------------------------------------------------
        # STEP 1  – rotate(gripper, zero_deg → ninety_deg) about Z axis
        # ------------------------------------------------------------------
        print("\n[Plan] Step 1 – rotate gripper 90° about Z")
        angle_rad = np.deg2rad(90.0)                      # 90°
        # Quaternion laid out as [x, y, z, w]: rotation of angle_rad about Z.
        target_quat = np.asarray([0.0, 0.0,
                                  np.sin(angle_rad / 2),
                                  np.cos(angle_rad / 2)],
                                 dtype=np.float32)
        obs, _, done = rotate(env, task, target_quat)                  # noqa: F405
        if done:
            print("[Plan] Terminated during rotate.")
            return

        # ------------------------------------------------------------------
        # STEP 2  – move-to-side(bottom_side_pos)
        # ------------------------------------------------------------------
        print("\n[Plan] Step 2 – move to bottom_side_pos")
        side_pos = _safe_get_pos("bottom_side_pos", refresh_positions())
        obs, _, done = move(env, task, side_pos)                       # noqa: F405
        if done:
            print("[Plan] Terminated during move-to-side.")
            return

        # ------------------------------------------------------------------
        # STEP 3  – move-to-anchor(bottom_anchor_pos)
        # ------------------------------------------------------------------
        print("\n[Plan] Step 3 – move to bottom_anchor_pos")
        anchor_pos = _safe_get_pos("bottom_anchor_pos", refresh_positions())
        obs, _, done = move(env, task, anchor_pos)                     # noqa: F405
        if done:
            print("[Plan] Terminated during move-to-anchor.")
            return

        # ------------------------------------------------------------------
        # STEP 4  – pick-drawer   (grasp handle located at anchor_pos)
        # ------------------------------------------------------------------
        print("\n[Plan] Step 4 – pick drawer handle")
        obs, _, done = pick(env, task,                                 # noqa: F405
                            target_pos=anchor_pos,
                            approach_distance=0.05,
                            approach_axis='-z')
        if done:
            print("[Plan] Terminated during pick-drawer.")
            return

        # ------------------------------------------------------------------
        # STEP 5  – pull(bottom drawer) 0.20 m in +X direction
        # ------------------------------------------------------------------
        print("\n[Plan] Step 5 – pull drawer 0.20 m along +X")
        obs, _, done = pull(env, task,                                 # noqa: F405
                            pull_distance=0.20,
                            pull_axis='x')
        if done:
            print("[Plan] Terminated during pull.")
            return

        # ------------------------------------------------------------------
        # STEP 6  – pick first tomato
        # ------------------------------------------------------------------
        print("\n[Plan] Step 6 – pick first tomato")
        tomato1_pos, tomato2_pos = _discover_tomatoes(refresh_positions(), 2)
        obs, _, done = pick(env, task,                                 # noqa: F405
                            target_pos=tomato1_pos,
                            approach_distance=0.15,
                            approach_axis='z')
        if done:
            print("[Plan] Terminated while picking first tomato.")
            return

        # ------------------------------------------------------------------
        # STEP 7  – place first tomato on plate
        # ------------------------------------------------------------------
        print("\n[Plan] Step 7 – place first tomato on plate")
        plate_pos = _safe_get_pos("plate", refresh_positions())
        obs, _, done = place(env, task,                                # noqa: F405
                             target_pos=plate_pos,
                             approach_distance=0.15,
                             approach_axis='z')
        if done:
            print("[Plan] Terminated while placing first tomato.")
            return

        # ------------------------------------------------------------------
        # STEP 8  – pick second tomato
        # ------------------------------------------------------------------
        print("\n[Plan] Step 8 – pick second tomato")
        # The second tomato may have been disturbed by the first pick/place.
        # Re-read its position from a single snapshot when it is available
        # under the 'item2' key; otherwise keep the position found in step 6.
        current_positions = refresh_positions()
        if "item2" in current_positions:
            tomato2_pos = _safe_get_pos("item2", current_positions)
        obs, _, done = pick(env, task,                                 # noqa: F405
                            target_pos=tomato2_pos,
                            approach_distance=0.15,
                            approach_axis='z')
        if done:
            print("[Plan] Terminated while picking second tomato.")
            return

        # ------------------------------------------------------------------
        # STEP 9  – place second tomato on plate
        # ------------------------------------------------------------------
        print("\n[Plan] Step 9 – place second tomato on plate")
        plate_pos = _safe_get_pos("plate", refresh_positions())
        obs, _, done = place(env, task,                                # noqa: F405
                             target_pos=plate_pos,
                             approach_distance=0.15,
                             approach_axis='z')
        if done:
            print("[Plan] Task signalled completion – success!")
        else:
            print("[Plan] High-level plan executed.  Environment did not "
                  "signal done, but objectives should now be satisfied.")

    except Exception as exc:                                            # pylint: disable=broad-except
        # Top-level boundary: report the full traceback, not just the message,
        # so failures inside skills remain debuggable.
        print(f"[ERROR] Exception encountered: {exc}")
        traceback.print_exc()

    finally:
        shutdown_environment(env)
        print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Script entry point: run the full plan once when executed directly.
    run_skeleton_task()