import numpy as np
from scipy.spatial.transform import Rotation as R

from env import setup_environment, shutdown_environment
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions

# ---- Import predefined skills (DO NOT redefine!) ----
from skill_code import pick, place, move, rotate, pull


# ---------------------------------------------------------------------------
#  Utility – quaternion for pure Z-axis rotation
# ---------------------------------------------------------------------------
def quaternion_from_euler_z(deg: float) -> np.ndarray:
    """Build the quaternion for a rotation of ``deg`` degrees about the Z axis.

    The quaternion is returned in SciPy's default scalar-last layout,
    i.e. ``(x, y, z, w)``.
    """
    z_rotation = R.from_euler(seq='z', angles=deg, degrees=True)
    quat_xyzw = z_rotation.as_quat()
    return quat_xyzw


# ---------------------------------------------------------------------------
#  Environment set-up
# ---------------------------------------------------------------------------
env, task = setup_environment()
done = False          # global episode-termination flag

try:
    # Reset the task; only the observation is needed here.
    _, obs = task.reset()

    # ----- optional video recording -------------------------------------------------
    # Wrap the task's step/get_observation callables with the recording helpers
    # (presumably so each call is also written to video — see the `video` module).
    init_video_writers(obs)
    task.step            = recording_step(task.step)
    task.get_observation = recording_get_observation(task.get_observation)
    # --------------------------------------------------------------------------------

    # --------------------------------------------------
    #  Retrieve all required object positions
    # --------------------------------------------------
    positions = get_object_positions()
    required_keys = (
        'middle_side_pos', 'middle_anchor_pos',
        'bottom_side_pos', 'bottom_anchor_pos',
        'rubbish', 'bin', 'waypoint1',
    )
    # Collect every missing key so a misconfigured scene is reported in one run
    # instead of one key at a time.
    missing_keys = [k for k in required_keys if k not in positions]
    if missing_keys:
        raise KeyError(
            "[Init] Missing key(s) in object_positions: "
            + ", ".join(repr(k) for k in missing_keys)
        )

    middle_side   = np.asarray(positions['middle_side_pos'])
    middle_anchor = np.asarray(positions['middle_anchor_pos'])
    bottom_side   = np.asarray(positions['bottom_side_pos'])
    bottom_anchor = np.asarray(positions['bottom_anchor_pos'])
    rubbish_pos   = np.asarray(positions['rubbish'])
    bin_pos       = np.asarray(positions['bin'])
    neutral_wp    = np.asarray(positions['waypoint1'])

    # 90-degree rotation about Z in (x, y, z, w) order, used by the frozen snippet.
    ninety_deg_quat = quaternion_from_euler_z(90.0)

except BaseException as e:
    # BaseException (not just Exception) so that Ctrl-C during set-up also
    # releases the environment; the error is always re-raised afterwards.
    print(f"[Init] Exception during environment set-up: {e}")
    shutdown_environment(env)
    raise


# Frozen skill sequence: for the middle and then the bottom drawer, rotate to
# `ninety_deg_quat`, move to the drawer's side position, then to its anchor,
# pick at the anchor (approach along -z, approach_distance=0.08) and pull
# 0.20 along x. Every call overwrites obs/reward/done, so only the `done`
# returned by the final pull is seen by the plan that follows.
# NOTE(review): these calls run outside any try/finally, so an exception raised
# by a skill here skips shutdown_environment(env). The code is frozen, so it is
# left as-is.
# ---------------------------------------------------------------------------
#  ---------------  F r o z e n   C o d e   S t a r t  ----------------------
# ---------------------------------------------------------------------------
obs, reward, done = rotate(env, task, target_quat=ninety_deg_quat)
obs, reward, done = move(env, task, target_pos=middle_side)
obs, reward, done = move(env, task, target_pos=middle_anchor)
obs, reward, done = pick(env, task, target_pos=middle_anchor,
                                 approach_distance=0.08, approach_axis='-z')
obs, reward, done = pull(env, task,
                                 pull_distance=0.20,   # 20 cm pull-out
                                 pull_axis='x')
obs, reward, done = rotate(env, task, target_quat=ninety_deg_quat)
obs, reward, done = move(env, task, target_pos=bottom_side)
obs, reward, done = move(env, task, target_pos=bottom_anchor)
obs, reward, done = pick(env, task, target_pos=bottom_anchor,
                                 approach_distance=0.08, approach_axis='-z')
obs, reward, done = pull(env, task,
                                 pull_distance=0.20,
                                 pull_axis='x')
# ---------------------------------------------------------------------------
#  ---------------  F r o z e n   C o d e   E n d    ------------------------
# ---------------------------------------------------------------------------


# ---------------------------------------------------------------------------
#  Remaining steps to complete the overall goal
# ---------------------------------------------------------------------------
try:
    if done:
        print("[Plan] Episode terminated during frozen snippet – aborting remainder.")
    else:
        # Preparatory stages, run in order. Each entry is
        #   (progress message, skill call, message if the episode ends during it).
        # The first stage that reports done=True ends the plan early. The
        # environment is shut down only in the `finally` clause below, so it is
        # never shut down twice.
        stages = [
            (
                "[Plan] Release drawer handle",
                # NOTE(review): assumes `place` releases the grasp at the end of its
                # motion; here it targets the bottom handle anchor with a short
                # approach (0.02, presumably metres) along -z.
                lambda: place(env, task,
                              target_pos=bottom_anchor,
                              approach_distance=0.02,
                              approach_axis='-z'),
                "[Plan] Episode finished early while releasing handle.",
            ),
            (
                "[Plan] Move to neutral waypoint",
                lambda: move(env, task, target_pos=neutral_wp),
                "[Plan] Episode finished early while retreating.",
            ),
            (
                "[Plan] Pick the rubbish",
                lambda: pick(env, task,
                             target_pos=rubbish_pos,
                             approach_distance=0.12,
                             approach_axis='-z'),
                "[Plan] Episode finished early after picking rubbish.",
            ),
        ]

        for announce, run_skill, early_exit_msg in stages:
            print(announce)
            obs, reward, done = run_skill()
            if done:
                print(early_exit_msg)
                break
        else:
            # Reached only when no preparatory stage terminated the episode:
            # place the rubbish into the bin and report the final status.
            print("[Plan] Place the rubbish into the bin")
            obs, reward, done = place(env, task,
                                      target_pos=bin_pos,
                                      approach_distance=0.12,
                                      approach_axis='-z')
            if done:
                print("[Plan] SUCCESS – Goal achieved!")
            else:
                print("[Plan] Plan executed – episode not terminated (done=False).")

finally:
    # Single cleanup point for every path: normal completion, early
    # termination, or an exception raised by a skill.
    shutdown_environment(env)
    print("[Shutdown] Environment closed. Bye!")