import numpy as np
from scipy.spatial.transform import Rotation as R

from env import setup_environment, shutdown_environment
# keep wildcard so that all low-level primitives remain in scope
from skill_code import *

from video import (
    init_video_writers,
    recording_step,
    recording_get_observation,
)

from object_positions import get_object_positions


# ------------------------------------------------------------------
# Helper utilities
# ------------------------------------------------------------------
def _safe_call(skill_fn, *args, **kwargs):
    """
    Thin wrapper around a low-level skill.  If the environment signals
    termination (done=True) we abort the high-level plan immediately so
    that the simulator shuts down cleanly.
    """
    obs, reward, done = skill_fn(*args, **kwargs)
    if done:
        print("[Task] Environment reported done=True – terminating early.")
    return obs, reward, done


def _find_available_drawer(positions_dict):
    """
    Returns (drawer_name, side_pos, anchor_pos) for the first drawer that
    exposes both ‘*_side_pos’ and ‘*_anchor_pos’ keys.  Preference order
    is bottom → middle → top (matching the most common unlocked drawer).
    """
    for drawer in ("bottom", "middle", "top"):
        side_key = f"{drawer}_side_pos"
        anchor_key = f"{drawer}_anchor_pos"
        if side_key in positions_dict and anchor_key in positions_dict:
            print(f"[Task] Selected drawer: {drawer}")
            return drawer, positions_dict[side_key], positions_dict[anchor_key]
    raise RuntimeError("No suitable drawer found (need side & anchor pos).")


def _resolve_tomato_names(positions_dict):
    """
    Convenience helper – some scenes use explicit ‘tomato1/2’, others use
    generic ‘item1/2’.  This routine returns a list with exactly two
    object keys for tomatoes.
    """
    preferred = ["tomato1", "tomato2"]
    fallback = ["item1", "item2"]

    result = [name for name in preferred if name in positions_dict]
    if len(result) < 2:
        for name in fallback:
            if name in positions_dict and name not in result:
                result.append(name)
            if len(result) == 2:
                break

    if len(result) != 2:
        raise RuntimeError(
            "Could not locate two tomato objects (tomato1/2 or item1/2)."
        )
    print(f"[Task] Tomatoes resolved as: {result}")
    return result


# ------------------------------------------------------------------
# Main task runner
# ------------------------------------------------------------------
def run_skeleton_task():
    """
    Run the oracle plan: open a drawer, then put two tomatoes on the plate.

    The environment is created and reset.  ``task.step`` and
    ``task.get_observation`` are wrapped with the video-recording
    decorators, and object positions are queried once.  The fixed skill
    sequence then runs: rotate, move to the drawer side, move to the anchor,
    grasp, pull, and then pick/place for each tomato.  The plan stops as soon
    as any skill reports ``done``.  The environment is shut down on every
    exit path (completion, early stop, or exception).

    Raises:
        RuntimeError: if no usable drawer, two tomato keys, or a ``plate``
            key can be found in the position dictionary.
    """
    print("===== Starting Skeleton Task =====")

    env, task = setup_environment()

    def episode_ended(skill, **params):
        # Run one skill on (env, task); True means the simulator signalled done.
        _obs, _reward, finished = _safe_call(skill, env, task, **params)
        return finished

    try:
        # ---- Reset the episode and attach video recording -------------
        _descriptions, obs = task.reset()
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # ---- Look up every target position once ------------------------
        positions = get_object_positions()
        print("[Task] Available position keys:", list(positions.keys()))

        drawer_name, side_pos, anchor_pos = _find_available_drawer(positions)
        tomato_names = _resolve_tomato_names(positions)
        if "plate" not in positions:
            raise RuntimeError("Plate position not found in positions dict.")
        plate_pos = positions["plate"]

        # ---- Steps 1-5: orient the gripper and open the drawer ---------
        print("[Plan] Step-1: rotate gripper 90° about Z.")
        # NOTE(review): this is an absolute orientation (90° yaw from the
        # identity frame), not a rotation relative to the current pose.
        quat_90_z = R.from_euler("xyz", [0, 0, 90], degrees=True).as_quat()
        if episode_ended(
            rotate, target_quat=quat_90_z, max_steps=150, threshold=0.05
        ):
            return

        print(f"[Plan] Step-2: move to {drawer_name} side position.")
        if episode_ended(move, target_pos=side_pos, max_steps=150, threshold=0.01):
            return

        print(f"[Plan] Step-3: move to {drawer_name} anchor position.")
        if episode_ended(
            move, target_pos=anchor_pos, max_steps=150, threshold=0.01
        ):
            return

        print("[Plan] Step-4: grasp the drawer handle.")
        # Short approach: the gripper is already next to the handle.
        if episode_ended(
            pick,
            target_pos=anchor_pos,
            approach_distance=0.05,
            approach_axis="z",
            max_steps=150,
        ):
            return

        print("[Plan] Step-5: pull drawer along +x to open.")
        # 0.12 is the pull travel; tune if the drawer needs more/less.
        if episode_ended(pull, pull_distance=0.12, pull_axis="x", max_steps=150):
            return

        # ---- Steps 6-9: move each tomato from the table to the plate ---
        for offset, tomato in enumerate(tomato_names):
            pick_step = 6 + 2 * offset  # 6 for the first tomato, 8 for the second

            print(f"[Plan] Step-{pick_step}: pick {tomato} from table.")
            if episode_ended(
                pick,
                target_pos=positions[tomato],
                approach_distance=0.15,
                approach_axis="z",
                max_steps=200,
            ):
                return

            print(f"[Plan] Step-{pick_step + 1}: place {tomato} onto plate.")
            if episode_ended(
                place,
                target_pos=plate_pos,
                approach_distance=0.15,
                approach_axis="z",
                max_steps=200,
            ):
                return

        print("[Task] Oracle plan executed successfully – goal achieved!")

    finally:
        shutdown_environment(env)
        print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Run the oracle plan only when executed as a script, not when imported.
    run_skeleton_task()