# run_combined_task.py
#
# Oracle execution for the goal:
#   “Select a drawer and open it fully, then pick up the rubbish
#    and leave it in the trash can.”
#
# The seven steps are dictated by the JSON specification:
#   1) rotate        (gripper, zero_deg → ninety_deg)
#   2) move-to-side  (gripper, bottom, nowhere-pos → side-pos-bottom)
#   3) move-to-anchor(gripper, bottom, side-pos-bottom → anchor-pos-bottom)
#   4) pick-drawer   (gripper, bottom, anchor-pos-bottom)
#   5) pull          (gripper, bottom)
#   6) pick          (rubbish, table)
#   7) place         (rubbish, bin)
#
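# All low-level work is handled by the predefined skills in
#   skill_code.py  →  {rotate, move, pull, pick, place}.
# Steps 6 and 7 are each preceded by an extra `move` to an approach
# pose 10 cm above the target (the rubbish and the bin, respectively).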
# ────────────────────────────────────────────────────────────────

import numpy as np
from scipy.spatial.transform import Rotation as R

# === Simulation / Recording helpers ===
from env   import setup_environment, shutdown_environment
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions

# === Low-level skills (already implemented) ===
from skill_code import *        # pick, place, move, rotate, pull


def safe_call(skill_fn, *args, **kwargs):
    """
    Invoke one low-level skill and announce if it completes the task.

    The skill is called as ``skill_fn(*args, **kwargs)`` and must return a
    3-item ``(obs, reward, done)`` sequence; anything else raises the usual
    unpacking ``ValueError``. When ``done`` is truthy, a message naming the
    skill is printed. Exceptions raised by the skill are not caught here —
    they propagate so the caller can clean up (e.g. in a ``finally``).

    Returns
    -------
    tuple
        ``(obs, reward, done)`` exactly as produced by the skill.
    """
    result = skill_fn(*args, **kwargs)
    observation, step_reward, finished = result
    if finished:
        skill_name = skill_fn.__name__
        print(f"[safe_call] Task finished while running: {skill_name}")
    return observation, step_reward, finished


# Drawers are tried in this order; the first one that has BOTH a side and an
# anchor position in object_positions is the one that gets opened.
_DRAWER_PREFERENCE = ('bottom', 'middle', 'top')

# Height added to the target's Z coordinate for the approach pose used before
# picking the rubbish and before placing it (0.10 ≈ 10 cm, assuming metres).
_APPROACH_HEIGHT = 0.10

# Heuristic for opening the drawer fully: pull 18 cm straight back along -X.
_PULL_DISTANCE = 0.18
_PULL_AXIS = '-x'


def _select_drawer_positions(positions):
    """
    Choose the drawer to open and return its ``(side_pos, anchor_pos)``.

    Drawers are tried in ``_DRAWER_PREFERENCE`` order (bottom → middle → top);
    a drawer qualifies only when both ``<name>_side_pos`` and
    ``<name>_anchor_pos`` are present in *positions*.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        Side and anchor positions as float arrays.

    Raises
    ------
    KeyError
        If no drawer has both positions available.
    """
    for drawer in _DRAWER_PREFERENCE:
        side_key = f'{drawer}_side_pos'
        anchor_key = f'{drawer}_anchor_pos'
        if side_key in positions and anchor_key in positions:
            return (np.asarray(positions[side_key], dtype=float),
                    np.asarray(positions[anchor_key], dtype=float))
    raise KeyError("Cannot find side/anchor positions for any drawer "
                   f"{_DRAWER_PREFERENCE} in object_positions.")


def _select_rubbish_position(positions):
    """
    Return the rubbish position as a float array.

    Prefers the explicit ``'rubbish'`` key and falls back to ``'item3'``.

    Raises
    ------
    KeyError
        If neither key is present in *positions*.
    """
    for key in ('rubbish', 'item3'):
        if key in positions:
            return np.asarray(positions[key], dtype=float)
    raise KeyError("Cannot find rubbish position ('rubbish' or 'item3') "
                   "in object_positions.")


def _select_bin_position(positions):
    """
    Return the trash-can (``'bin'``) position as a float array.

    Raises
    ------
    KeyError
        If ``'bin'`` is not present in *positions*.
    """
    if 'bin' not in positions:
        raise KeyError("Cannot find bin position in object_positions.")
    return np.asarray(positions['bin'], dtype=float)


def _approach_pose(pos, height=_APPROACH_HEIGHT):
    """Return a copy of *pos* raised by *height* along Z (index 2)."""
    above = np.array(pos, dtype=float)  # np.array copies; *pos* is untouched
    above[2] += height
    return above


def _gripper_quat_rotated_about_z(task, degrees):
    """
    Return the current gripper orientation composed with a *degrees* rotation
    about Z, as a quaternion in scipy's ``as_quat`` order.

    The extra rotation is composed on the right (``current * extra``), i.e.
    it is applied about the gripper's own Z-axis. This is a rotation relative
    to the current orientation, not an absolute target angle.
    """
    # NOTE(review): assumes gripper_pose is [x, y, z, qx, qy, qz, qw]
    # (scalar-last, scipy's default quaternion order) — TODO confirm.
    current_quat = np.array(task.get_observation().gripper_pose[3:7])
    extra_rot = R.from_euler('z', degrees, degrees=True)
    return (R.from_quat(current_quat) * extra_rot).as_quat()


def _execute_oracle_plan(env, task, drawer_side_pos, drawer_anchor_pos,
                         rubbish_pos, bin_pos):
    """
    Run the oracle skills in order, stopping as soon as one reports done.

    Each step is a zero-argument callable so that values depending on the
    live simulation (e.g. the current gripper orientation) are computed only
    when that step actually runs.

    Returns
    -------
    bool
        The ``done`` flag of the last skill that was executed.
    """
    steps = (
        # STEP-1 rotate gripper +90° about its own Z-axis (zero_deg → ninety_deg)
        lambda: safe_call(rotate, env, task,
                          _gripper_quat_rotated_about_z(task, 90)),
        # STEP-2 move-to-side
        lambda: safe_call(move, env, task, drawer_side_pos),
        # STEP-3 move-to-anchor (drawer handle)
        lambda: safe_call(move, env, task, drawer_anchor_pos),
        # STEP-4 pick-drawer (grasp the handle)
        lambda: safe_call(pick, env, task, drawer_anchor_pos),
        # STEP-5 pull the drawer open
        lambda: safe_call(pull, env, task,
                          pull_distance=_PULL_DISTANCE, pull_axis=_PULL_AXIS),
        # STEP-6 approach above the rubbish, then pick it.
        # NOTE(review): nothing here releases the drawer handle after `pull`;
        # this assumes `move`/`pick` handle an already-closed gripper
        # (e.g. `pick` opens it first) — TODO confirm against skill_code.
        lambda: safe_call(move, env, task, _approach_pose(rubbish_pos)),
        lambda: safe_call(pick, env, task, rubbish_pos),
        # STEP-7 approach above the bin, then place the rubbish in it
        lambda: safe_call(move, env, task, _approach_pose(bin_pos)),
        lambda: safe_call(place, env, task, bin_pos),
    )

    done = False
    for run_step in steps:
        _obs, _reward, done = run_step()
        if done:
            break
    return done


def run_combined_task():
    """
    Open a drawer, then move the rubbish into the bin, using the oracle plan.

    Sets up the environment, wraps the task with the video-recording helpers,
    resolves the required object positions, and executes the skill sequence.
    Any exception is printed and re-raised; the environment is always shut
    down in ``finally``.

    Raises
    ------
    KeyError
        If a required drawer, rubbish, or bin position is missing.
    Exception
        Anything raised by the environment or the skills is re-raised.
    """
    print("========== Combined Task START ==========")

    # 0) Environment / task initialisation (outside the try: if setup fails
    #    there is no env to shut down).
    env, task = setup_environment()
    try:
        _descriptions, obs = task.reset()

        # Video recording: wrap step/get_observation with the recording
        # helpers (presumably so each subsequent call is captured).
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # 1) Resolve every position up front, so a missing key fails before
        #    the robot starts moving.
        positions = get_object_positions()
        drawer_side_pos, drawer_anchor_pos = _select_drawer_positions(positions)
        rubbish_pos = _select_rubbish_position(positions)
        bin_pos = _select_bin_position(positions)

        # 2) Execute the oracle plan.
        done = _execute_oracle_plan(env, task, drawer_side_pos,
                                    drawer_anchor_pos, rubbish_pos, bin_pos)

        # 3) Final result. Only done=True is a confirmed success.
        if done:
            print("[run_combined_task] Task reported done=True – success!")
        else:
            print("[run_combined_task] Executed all oracle steps, but the task "
                  "did not report done=True – success not confirmed.")

    except Exception as e:
        # Any error gets printed before shutting down the simulation.
        print(f"[run_combined_task] Exception occurred: {e}")
        raise
    finally:
        shutdown_environment(env)
        print("========== Combined Task END ==========")


if __name__ == "__main__":
    run_combined_task()