# run_combined_task.py
#
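# Executable task script that follows the oracle plan defined in the
# “Specification” section.  All robot motions go exclusively through the
# pre-implemented skills imported from skill_code; environment setup, video
# recording and object-position lookup come from the env, video and
# object_positions modules.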

import numpy as np
from scipy.spatial.transform import Rotation as R

from env import setup_environment, shutdown_environment
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions

# --------------------------------------------------------------------------- #
# Import the predefined skills (DO NOT redefine them)                         #
# --------------------------------------------------------------------------- #
from skill_code import (
    pick,
    place,
    move,
    rotate,
    pull,
    normalize_quaternion,   # utility already defined in skill_code
)

# --------------------------------------------------------------------------- #
# Helper utilities                                                            #
# --------------------------------------------------------------------------- #
def compose_relative_rotation(orig_quat_xyzw, axis: str, deg: float):
    """Return `orig_quat_xyzw` rotated by `deg` degrees about one of its own axes.

    The increment is right-multiplied (``base * delta``), so scipy applies
    ``delta`` first. The rotation therefore happens about the axis of the
    frame described by `orig_quat_xyzw` (local/body frame), not about the
    fixed world axis.

    Args:
        orig_quat_xyzw: Starting orientation as a scalar-last ``[x, y, z, w]``
            quaternion (scipy ``Rotation.from_quat`` default convention).
        axis: One of ``'x'``, ``'y'`` or ``'z'``.
        deg: Rotation angle in degrees.

    Returns:
        The composed orientation (scalar-last), passed through
        ``normalize_quaternion`` from skill_code.

    Raises:
        ValueError: If `axis` is not ``'x'``, ``'y'`` or ``'z'``.
    """
    # Explicit check rather than `assert`: asserts are removed under
    # `python -O`, which would silently skip this validation.
    if axis not in ('x', 'y', 'z'):
        raise ValueError('Axis must be x, y, or z.')
    base_rot = R.from_quat(orig_quat_xyzw)
    delta_rot = R.from_euler(axis, deg, degrees=True)
    # base * delta -> rotate about the *local* axis of the current orientation.
    new_rot = (base_rot * delta_rot).as_quat()
    return normalize_quaternion(new_rot)


def safe_skill_call(skill_fn, *args, **kwargs):
    """Call ``skill_fn(*args, **kwargs)``, logging any exception before re-raising it.

    The original exception is always re-raised unchanged, so an enclosing
    ``finally`` (e.g. environment shutdown) still runs and the traceback
    stays intact.

    Args:
        skill_fn: Any callable: plain function, ``functools.partial``,
            bound method or callable object.
        *args: Positional arguments forwarded verbatim to `skill_fn`.
        **kwargs: Keyword arguments forwarded verbatim to `skill_fn`.

    Returns:
        Whatever `skill_fn` returns.
    """
    try:
        return skill_fn(*args, **kwargs)
    except Exception as exc:
        # Not every callable has __name__ (functools.partial, callable
        # instances). Reading it unguarded would raise AttributeError here
        # and mask the skill's real exception.
        skill_name = getattr(skill_fn, '__name__', repr(skill_fn))
        print(f"[ERROR] Exception during skill '{skill_name}': {exc}")
        raise


# --------------------------------------------------------------------------- #
# Main task logic                                                             #
# --------------------------------------------------------------------------- #
def run_combined_task():
    """Run the oracle plan: open the bottom drawer, then move two tomatoes to the plate.

    Plan, in order:
      1.   Rotate the gripper +90 degrees about z, relative to its current
           orientation.
      2-3. Move to the bottom drawer's side position, then its anchor position.
      4-5. Grasp the handle at the anchor position and pull the drawer open.
      6-9. Pick ``tomato1`` and then ``tomato2``, placing each on ``plate``.

    If any skill before the final placement reports ``done``, a message is
    printed and the function returns early. A missing object position
    raises ``RuntimeError``, but only when its step is reached. Skill
    exceptions propagate after ``safe_skill_call`` logs them. Once
    ``setup_environment`` has succeeded, the environment is shut down on
    every exit path.
    """
    print("\n================  STARTING COMBINED TASK  ================\n")

    env, task = setup_environment()
    try:
        # Reset; the task description list is not used by this plan.
        _descriptions, obs = task.reset()

        # Wrap step/get_observation with the video-recording helpers.
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # Object positions are read once here and reused by every step below.
        object_positions = {
            name: np.asarray(value)
            for name, value in get_object_positions().items()
        }

        def lookup(*aliases):
            """Return the position under the first alias present, else None."""
            return next(
                (object_positions[alias] for alias in aliases
                 if alias in object_positions),
                None,
            )

        def run_skill(skill, early_exit_msg, **skill_kwargs):
            """Run `skill` against this env/task and return (obs, reward, done).

            If the skill reports ``done`` and `early_exit_msg` is not None,
            the message is printed; the caller decides whether to stop.
            """
            new_obs, step_reward, is_done = safe_skill_call(
                skill, env=env, task=task, **skill_kwargs
            )
            if is_done and early_exit_msg is not None:
                print(early_exit_msg)
            return new_obs, step_reward, is_done

        # Step-1: rotate(gripper, zero_deg -> ninety_deg), applied relative
        # to the current gripper orientation.
        print("\n--- [Step-1] ROTATE gripper to 90° about z-axis ---")
        # NOTE(review): assumes gripper_pose is [x, y, z, qx, qy, qz, qw].
        start_quat = normalize_quaternion(obs.gripper_pose[3:7])
        obs, reward, done = run_skill(
            rotate,
            "[Task] Episode finished unexpectedly after Step-1.",
            target_quat=compose_relative_rotation(start_quat, axis='z', deg=90.0),
        )
        if done:
            return

        # Step-2: move(nowhere-pos -> side-pos-bottom)
        print("\n--- [Step-2] MOVE to drawer SIDE position (bottom) ---")
        side_pos = lookup('bottom_side_pos', 'side-pos-bottom')
        if side_pos is None:
            raise RuntimeError("Position for 'bottom_side_pos' not found.")
        obs, reward, done = run_skill(
            move,
            "[Task] Episode finished unexpectedly after Step-2.",
            target_pos=side_pos,
        )
        if done:
            return

        # Step-3: move(side -> anchor-pos-bottom)
        print("\n--- [Step-3] MOVE to drawer ANCHOR position (bottom) ---")
        anchor_pos = lookup('bottom_anchor_pos', 'anchor-pos-bottom')
        if anchor_pos is None:
            raise RuntimeError("Position for 'bottom_anchor_pos' not found.")
        obs, reward, done = run_skill(
            move,
            "[Task] Episode finished unexpectedly after Step-3.",
            target_pos=anchor_pos,
        )
        if done:
            return

        # Step-4: grasp the drawer handle via the generic `pick` skill.
        print("\n--- [Step-4] PICK the drawer handle (bottom) ---")
        obs, reward, done = run_skill(
            pick,
            "[Task] Episode finished unexpectedly after Step-4.",
            target_pos=anchor_pos,
            approach_distance=0.10,  # shorter approach to the handle
            approach_axis='y',       # approach from the front
        )
        if done:
            return

        # Step-5: pull the drawer fully open.
        print("\n--- [Step-5] PULL to open the drawer fully ---")
        obs, reward, done = run_skill(
            pull,
            "[Task] Episode finished unexpectedly after Step-5.",
            pull_distance=0.22,  # tuned distance; adjust if necessary
            pull_axis='x',       # positive-X assumed outward
        )
        if done:
            return

        # Steps 6 & 7: tomato1 -> plate.
        print("\n--- [Step-6 & 7] PICK tomato1 and PLACE on plate ---")
        tomato1_pos = lookup('tomato1')
        plate_pos = lookup('plate')
        if tomato1_pos is None or plate_pos is None:
            raise RuntimeError("Positions for 'tomato1' or 'plate' missing.")

        # Shared top-down approach used for every tomato pick/place.
        top_down = {'approach_distance': 0.15, 'approach_axis': 'z'}

        obs, reward, done = run_skill(
            pick,
            "[Task] Episode finished unexpectedly while picking tomato1.",
            target_pos=tomato1_pos,
            **top_down,
        )
        if done:
            return
        obs, reward, done = run_skill(
            place,
            "[Task] Episode finished unexpectedly while placing tomato1.",
            target_pos=plate_pos,
            **top_down,
        )
        if done:
            return

        # Steps 8 & 9: tomato2 -> plate.
        print("\n--- [Step-8 & 9] PICK tomato2 and PLACE on plate ---")
        tomato2_pos = lookup('tomato2')
        if tomato2_pos is None:
            raise RuntimeError("Position for 'tomato2' missing.")
        obs, reward, done = run_skill(
            pick,
            "[Task] Episode finished unexpectedly while picking tomato2.",
            target_pos=tomato2_pos,
            **top_down,
        )
        if done:
            return
        # Final placement: `done` here is read as success below, so no
        # early-exit message is attached.
        obs, reward, done = run_skill(
            place,
            None,
            target_pos=plate_pos,
            **top_down,
        )

        # Success check.
        if not done:
            print("\n[Task] Plan finished, environment did not flag 'done'.")
            print("        Manually verify success predicates if needed.")
        else:
            print("\n[Task] SUCCESS – goal achieved!  Reward:", reward)

    finally:
        shutdown_environment(env)
        print("\n================  ENVIRONMENT SHUTDOWN  ================\n")


# --------------------------------------------------------------------------- #
# Entry point                                                                 #
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    # Execute the plan only when this file is run as a script, not on import.
    run_combined_task()