import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

from skill_code import *  # (you don't need to redefine primitives like move, pick, place)

from video import init_video_writers, recording_step, recording_get_observation

from object_positions import get_object_positions

# Extra import needed to build quaternions for the rotate skill
from scipy.spatial.transform import Rotation as R


# ---------------------------------------------------------------------------
# Utility helpers
# ---------------------------------------------------------------------------

def fetch_position(positions_dict, *name_variants):
    """
    Look up a position under the first key variant present in the dictionary.

    Candidate keys are checked in the order they are passed; the value stored
    under the earliest one that exists is returned, later candidates are
    ignored.

    Args:
        positions_dict (dict): {"obj_name": np.ndarray([x, y, z]), ...}
        *name_variants (str): Candidate keys, tried left to right.

    Returns:
        np.ndarray: 3‑vector position.

    Raises:
        KeyError: If no candidate key is present in the dictionary.
    """
    # A private sentinel distinguishes "nothing matched" from any stored value.
    not_found = object()
    hit = next(
        (positions_dict[key] for key in name_variants if key in positions_dict),
        not_found,
    )
    if hit is not_found:
        raise KeyError(f"None of the variants {name_variants} found in positions dictionary.")
    return hit


# ---------------------------------------------------------------------------
# Main task runner
# ---------------------------------------------------------------------------

def run_skeleton_task():
    """
    Execute the oracle plan: open the bottom drawer, then bin the rubbish.

    Steps: rotate the gripper 90° about Z, move to the drawer's side and
    anchor positions, grip and pull the drawer, then pick the rubbish and
    place it in the bin. The run stops early as soon as any skill reports
    ``done``.

    All object positions the plan needs are resolved before the robot moves,
    so a missing position aborts the run without executing any motion.
    Errors are reported on stdout rather than propagated (other than
    non-``Exception`` interrupts). The environment is always shut down and
    the closing banner is printed on every normal exit path.
    """
    import traceback  # local import keeps the file-level import block unchanged

    print("===== Starting Skeleton Task =====")

    # === Environment Setup ===
    env, task = setup_environment()
    try:
        # Reset the task
        descriptions, obs = task.reset()

        # Optional video writers (safe even if not used later)
        init_video_writers(obs)

        # Wrap step / get_observation for video recording
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # Retrieve all object positions known to the helper
        positions = get_object_positions()

        # ------------------------------------------------------------
        # Mapping between PDDL names (spec) and RLBench object names
        # ------------------------------------------------------------
        name_map = {
            # Drawer related
            "nowhere-pos": ("nowhere_pos", "waypoint1"),  # starting waypoint
            "side-pos-bottom": ("bottom_side_pos",),
            "anchor-pos-bottom": ("bottom_anchor_pos",),
            # Rubbish disposal
            "rubbish": ("rubbish",),
            "bin": ("bin",),
        }

        # Resolve every position the plan uses *before* moving the robot, so
        # a missing key aborts cleanly instead of after half the plan has run.
        # Only this lookup is guarded for KeyError: a KeyError raised inside a
        # skill is a different failure and must not be reported as a missing
        # position.
        required = ("side-pos-bottom", "anchor-pos-bottom", "rubbish", "bin")
        try:
            targets = {
                pddl_name: fetch_position(positions, *name_map[pddl_name])
                for pddl_name in required
            }
        except KeyError as ke:
            print(f"[Error] Failed to resolve object position: {ke}")
            return

        # ------------------------------------------------------------
        # Execute Oracle Plan (Spec.steps)
        # ------------------------------------------------------------

        # -------- Step 1: rotate(gripper, zero_deg, ninety_deg) -------
        print("\n[Plan] Step 1 – Rotate gripper from zero_deg to ninety_deg")
        # Build target quaternion (rotate 90° around Z). scipy's as_quat()
        # returns scalar-last (x, y, z, w) order — presumably what `rotate`
        # expects; NOTE(review): confirm against skill_code.
        target_quat = R.from_euler('xyz', [0, 0, 90], degrees=True).as_quat()
        obs, reward, done = rotate(env, task, target_quat)
        if done:
            print("[Plan] Task finished after rotate.")
            return

        # -------- Step 2: move(gripper, bottom, nowhere-pos, side-pos-bottom) -------
        print("\n[Plan] Step 2 – Move gripper to drawer’s side position")
        obs, reward, done = move(env, task, targets["side-pos-bottom"])
        if done:
            print("[Plan] Task finished after move #2.")
            return

        # -------- Step 3: move(gripper, bottom, side-pos-bottom, anchor-pos-bottom) -------
        print("\n[Plan] Step 3 – Move gripper from side to anchor position")
        target_pos_anchor = targets["anchor-pos-bottom"]
        obs, reward, done = move(env, task, target_pos_anchor)
        if done:
            print("[Plan] Task finished after move #3.")
            return

        # -------- Step 4: pick‑drawer(gripper, bottom, anchor-pos-bottom) -------
        print("\n[Plan] Step 4 – Grip the drawer handle (pick‑drawer equivalent)")
        obs, reward, done = pick(
            env,
            task,
            target_pos=target_pos_anchor,
            approach_distance=0.10,
            approach_axis='y',  # choose axis orthogonal to drawer face (adjust if needed)
        )
        if done:
            print("[Plan] Task finished after pick‑drawer.")
            return

        # -------- Step 5: pull(gripper, bottom) -------
        print("\n[Plan] Step 5 – Pull the drawer open")
        # Pull along the 'x' axis by 0.20 m (distance may vary by scenario).
        obs, reward, done = pull(
            env,
            task,
            pull_distance=0.20,
            pull_axis='x',
        )
        if done:
            print("[Plan] Task finished after pull.")
            return

        # -------- Step 6: pick(rubbish, table) -------
        print("\n[Plan] Step 6 – Pick the rubbish from the table")
        obs, reward, done = pick(
            env,
            task,
            target_pos=targets["rubbish"],
            approach_distance=0.15,
            approach_axis='z',
        )
        if done:
            print("[Plan] Task finished after picking rubbish.")
            return

        # -------- Step 7: place(rubbish, bin) -------
        print("\n[Plan] Step 7 – Place the rubbish into the bin")
        obs, reward, done = place(
            env,
            task,
            target_pos=targets["bin"],
            approach_distance=0.15,
            approach_axis='z',
        )
        if done:
            print("[Plan] Task finished after placing rubbish.")
        else:
            print("[Plan] Plan completed – environment reports done = False (no terminal signal).")
        print(f"[Plan] Final reward: {reward}")

    except Exception as e:
        # Top-level boundary: report, including the traceback, instead of
        # hiding where the failure happened.
        print(f"[Error] Unexpected exception occurred: {e}")
        traceback.print_exc()
    finally:
        shutdown_environment(env)
        # Printed here so early "task done" returns also close the log.
        print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Run the plan only when executed as a script, not when imported.
    run_skeleton_task()
