# run_task.py

import numpy as np
from scipy.spatial.transform import Rotation as R

from env import setup_environment, shutdown_environment
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions

# === Import pre-implemented skills ===
from skill_code import rotate, move, pull, pick, place


# ----------------------------------------------------------------------
# Helper utilities
# ----------------------------------------------------------------------
def normalize_quaternion(q):
    """
    Return ``q`` scaled to unit Euclidean length.

    The component ordering is left untouched; only the magnitude changes.

    Parameters
    ----------
    q : array_like
        Quaternion components (list, tuple or ndarray).

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``q`` and unit norm.

    Raises
    ------
    ValueError
        If the norm of ``q`` is zero or not finite. Such a quaternion cannot
        represent a rotation. Dividing by it would otherwise silently produce
        NaN/inf values.
    """
    # asarray lets plain sequences work too (list / float is a TypeError).
    q = np.asarray(q, dtype=float)
    norm = np.linalg.norm(q)
    if not np.isfinite(norm) or norm == 0.0:
        raise ValueError(
            f"[normalize_quaternion] Cannot normalize quaternion with norm {norm}: {q}"
        )
    return q / norm


def lookup_position(name, cache):
    """
    Robustly obtain the 3-D position of an object or waypoint.

    1) First look in the dictionary returned by ``get_object_positions``.
    2) If the key is not found, fall back to a direct PyRep ``Shape`` lookup
       and memoize the result in ``cache``.

    Parameters
    ----------
    name : str
        Object / waypoint name to resolve.
    cache : dict
        Mapping of name -> position. It is mutated in place when the
        simulator fallback succeeds, so later lookups hit the cache.

    Returns
    -------
    numpy.ndarray
        The position, converted with ``np.asarray``.

    Raises
    ------
    RuntimeError
        If ``name`` is not cached and the simulator lookup fails. This
        includes the case where PyRep cannot be imported. The underlying
        exception is chained as ``__cause__``.
    """
    if name in cache:
        return np.asarray(cache[name])

    # Fallback: direct query from the simulator. PyRep is imported here,
    # only when this path is actually needed.
    try:
        from pyrep.objects.shape import Shape
        pos = np.asarray(Shape(name).get_position())
    except Exception as e:
        raise RuntimeError(
            f"[lookup_position] Could not find object '{name}': {e}"
        ) from e

    # Cache the same ndarray type we return, so hits and misses agree.
    cache[name] = pos
    return pos


# ----------------------------------------------------------------------
# Main routine that executes the oracle plan
# ----------------------------------------------------------------------
def _finished_early(done, stage):
    """
    Print the standard early-exit message when ``done`` is truthy.

    ``stage`` names the plan step that just ran. Returns ``bool(done)`` so
    the caller can ``return`` immediately.
    """
    if done:
        print(f"[Early-Exit] Task finished after {stage}.")
    return bool(done)


def run_task():
    """
    Execute the 7-step oracle plan: open the bottom drawer, then move the
    rubbish into the bin.

    ``shutdown_environment`` is always called once setup has succeeded. That
    holds whether the plan completes, exits early because a skill reported
    ``done``, or raises.
    """
    print("=====  Combined-Domain Task (Open Drawer + Dispose Rubbish)  =====")

    # ------------------------------------------------------------------
    # 1)  Environment setup
    # ------------------------------------------------------------------
    # Kept outside the ``try``: if setup itself fails, there is no
    # environment to shut down.
    env, task = setup_environment()
    try:
        _descriptions, obs = task.reset()

        init_video_writers(obs)                     # Optional recording

        # Wrap step / get_observation with video capture helpers so that the
        # calls made by the skills below go through the wrappers.
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # Snapshot of all known positions (dictionary). lookup_position adds
        # names it has to query directly from the simulator.
        obj_pos_dict = get_object_positions()

        def pos(name):
            """Resolve ``name`` to a position using the shared snapshot cache."""
            return lookup_position(name, obj_pos_dict)

        # ------------------------------------------------------------------
        # 2)  Execute the oracle plan (7 steps)
        # ------------------------------------------------------------------
        # Step-1: rotate(gripper, zero_deg → ninety_deg)
        print("\n--- [Plan-Step 1] rotate gripper zero→90 deg ---")
        obs = task.get_observation()
        current_quat = normalize_quaternion(obs.gripper_pose[3:7])

        # 90-deg rotation about the Z-axis. SciPy composes ``p * q`` as
        # "apply q, then p", so the Z turn is applied after the current
        # orientation (about the reference-frame Z axis, not the gripper's
        # own). NOTE(review): confirm this is the intended axis.
        target_quat = (R.from_euler('z', 90, degrees=True) *
                       R.from_quat(current_quat)).as_quat()
        obs, reward, done = rotate(env, task, target_quat)
        if _finished_early(done, "rotate"):
            return

        # Step-2: move-to-side (nowhere-pos → bottom_side_pos)
        print("\n--- [Plan-Step 2] move to bottom_side_pos ---")
        bottom_side = pos('bottom_side_pos')  # alias for side-pos-bottom
        obs, reward, done = move(env, task, bottom_side)
        if _finished_early(done, "move-to-side"):
            return

        # Step-3: move-to-anchor (side_pos → bottom_anchor_pos)
        print("\n--- [Plan-Step 3] move to bottom_anchor_pos ---")
        bottom_anchor = pos('bottom_anchor_pos')   # alias for anchor-pos-bottom
        obs, reward, done = move(env, task, bottom_anchor)
        if _finished_early(done, "move-to-anchor"):
            return

        # Step-4: pick-drawer  (grasp drawer handle at anchor position)
        print("\n--- [Plan-Step 4] pick drawer handle ---")
        obs, reward, done = pick(env, task,
                                 target_pos=bottom_anchor,
                                 approach_distance=0.05,
                                 approach_axis='y')   # approach along drawer normal
        if _finished_early(done, "pick-drawer"):
            return

        # Step-5: pull (open the drawer)
        # NOTE(review): the handle was approached along 'y' but is pulled
        # along 'x'. Confirm both match the drawer's actual axes.
        print("\n--- [Plan-Step 5] pull drawer open ---")
        obs, reward, done = pull(env, task,
                                 pull_distance=0.12,   # 12 cm pull
                                 pull_axis='x')
        if _finished_early(done, "pull"):
            return

        # Step-6: pick rubbish from table
        # NOTE(review): no release/place call follows the pull, so the drawer
        # handle is presumably still grasped here. This relies on ``pick``
        # opening the gripper first. Confirm.
        print("\n--- [Plan-Step 6] pick rubbish from table ---")
        rubbish_pos = pos('rubbish')
        obs, reward, done = pick(env, task,
                                 target_pos=rubbish_pos,
                                 approach_distance=0.15,
                                 approach_axis='z')
        if _finished_early(done, "picking rubbish"):
            return

        # Step-7: place rubbish into bin
        print("\n--- [Plan-Step 7] place rubbish into bin ---")
        bin_pos = pos('bin')
        obs, reward, done = place(env, task,
                                  target_pos=bin_pos,
                                  approach_distance=0.15,
                                  approach_axis='z')
        if done:
            print("[Success] Task completed! Reward:", reward)
        else:
            print("[Notice] Task ended with done=False. Check success conditions.")

    finally:
        shutdown_environment(env)
        print("=====  Task Shutdown Complete  =====")


if __name__ == "__main__":
    # Script entry point: run the oracle plan once when this file is executed
    # directly. Importing the module does not start the task.
    run_task()
