# run_task.py
#
# Executable task script that follows the oracle plan listed in the
# specification.  Only the already-implemented skills from skill_code.py
# are invoked – no new skills are defined here (just small helpers).
#
# High-level oracle plan (7 steps, matching the specification):
#   1) rotate      – (gripper  zero_deg → ninety_deg)
#   2) move        – to drawer “side” position
#   3) move        – to drawer “anchor” position
#   4) pick        – grasp the drawer handle
#   5) pull        – pull the drawer open
#   6) pick        – pick the rubbish object from the table
#   7) place       – drop the rubbish into the bin

import numpy as np
from pyrep.objects.shape import Shape                 # noqa: F401
from pyrep.objects.proximity_sensor import ProximitySensor   # noqa: F401

from env import setup_environment, shutdown_environment
from skill_code import move, rotate, pick, pull, place
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


# --------------------------------------------------------------------------- #
#  Helper utilities                                                           #
# --------------------------------------------------------------------------- #
def _safe_get(pos_dict, *keys):
    """Look up the first candidate key that is present in *pos_dict*.

    Candidates are tried in the order given; membership is tested with
    ``in`` before indexing, so only keys that already exist are read.

    Raises:
        KeyError: if none of *keys* is present in *pos_dict*.
    """
    not_found = object()  # private sentinel: any real key (even None) differs
    hit = next((candidate for candidate in keys if candidate in pos_dict),
               not_found)
    if hit is not_found:
        raise KeyError(f'None of the keys {keys} exist in object_positions.')
    return pos_dict[hit]


def _quat_from_z_rotation(deg: float) -> np.ndarray:
    """Build a float32 quaternion, ordered (x, y, z, w), for a Z-axis turn.

    Args:
        deg: rotation angle about the Z axis, in degrees.

    Returns:
        A length-4 ``np.float32`` array ``[0, 0, sin(θ/2), cos(θ/2)]``.
    """
    half_angle = np.deg2rad(deg) / 2.0
    components = [0.0, 0.0, np.sin(half_angle), np.cos(half_angle)]
    return np.array(components, dtype=np.float32)


# --------------------------------------------------------------------------- #
#  Main task                                                                  #
# --------------------------------------------------------------------------- #
def _abort_if_done(done, step: int) -> None:
    """Raise when a skill reports ``done`` before the plan has finished.

    Intermediate steps are expected to leave the episode running; a truthy
    ``done`` mid-plan is treated as a failure, so the remaining steps are not
    attempted.
    """
    if done:
        raise RuntimeError(f'[STEP {step}] Task terminated unexpectedly.')


def _resolve_positions(pos):
    """Resolve the object positions the oracle plan needs.

    Names depend on the RLBench scene, so several aliases are tried for each
    object (first match wins).

    Returns:
        Tuple ``(drawer_side, drawer_anchor, rubbish_pos, bin_pos)``.

    Raises:
        KeyError: naming every alias tried, if an object cannot be found.
    """
    # Drawer positions – use the bottom drawer (known unlocked)
    drawer_side = _safe_get(pos,
                            'bottom_side_pos',
                            'bottom_side',
                            'bottom_drawer_side')
    drawer_anchor = _safe_get(pos,
                              'bottom_anchor_pos',
                              'bottom_anchor',
                              'bottom_drawer_anchor')
    # Rubbish object: prefer an explicit "rubbish" entry, else item3/2/1
    rubbish_pos = _safe_get(pos, 'rubbish', 'item3', 'item2', 'item1')
    bin_pos = _safe_get(pos, 'bin', 'bin_pos')
    return drawer_side, drawer_anchor, rubbish_pos, bin_pos


def run_task() -> None:
    """Execute the 7-step oracle plan: open the bottom drawer, then bin rubbish.

    Sets up the environment, wraps ``task.step`` / ``task.get_observation``
    for video recording, resolves object positions and runs the skills in
    order.  Once setup has succeeded, the environment is always shut down.

    Raises:
        KeyError: if a required object position cannot be found.
        RuntimeError: if an intermediate step reports ``done``.
    """
    print('\n======================  START TASK  ======================\n')

    env, task = setup_environment()
    try:
        # -----------------------------------------------------------
        #  Environment reset & optional video recording
        # -----------------------------------------------------------
        _, obs = task.reset()
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # -----------------------------------------------------------
        #  Retrieve object positions
        # -----------------------------------------------------------
        pos = get_object_positions()
        drawer_side, drawer_anchor, rubbish_pos, bin_pos = _resolve_positions(pos)

        # -----------------------------------------------------------
        #  Execute the oracle plan step-by-step.  Every step before the
        #  last aborts via _abort_if_done, so no "if not done" guards are
        #  needed between them.
        # -----------------------------------------------------------

        # STEP-1  rotate gripper from 0° → 90° about Z
        print('\n[STEP 1] rotate gripper (0° → 90° about Z)')
        _, _, done = rotate(env, task, _quat_from_z_rotation(90.0))
        _abort_if_done(done, 1)

        # STEP-2  move to drawer “side” approach pose
        print('\n[STEP 2] move → drawer_side')
        _, _, done = move(env, task, target_pos=np.asarray(drawer_side))
        _abort_if_done(done, 2)

        # STEP-3  move to drawer “anchor” (handle) pose
        print('\n[STEP 3] move → drawer_anchor')
        _, _, done = move(env, task, target_pos=np.asarray(drawer_anchor))
        _abort_if_done(done, 3)

        # STEP-4  pick / grasp the drawer handle
        print('\n[STEP 4] pick → drawer handle')
        _, _, done = pick(env,
                          task,
                          target_pos=np.asarray(drawer_anchor),
                          approach_distance=0.10,
                          approach_axis='-y')      # approach from front
        _abort_if_done(done, 4)

        # STEP-5  pull the drawer open (−X direction, ~0.25 m)
        print('\n[STEP 5] pull → open drawer')
        _, _, done = pull(env,
                          task,
                          pull_distance=0.25,
                          pull_axis='-x')
        _abort_if_done(done, 5)

        # STEP-6  pick the rubbish object from the table
        print('\n[STEP 6] move → rubbish')
        _, _, done = move(env, task, target_pos=np.asarray(rubbish_pos))
        if not done:
            print('[STEP 6] pick → rubbish')
            _, _, done = pick(env,
                              task,
                              target_pos=np.asarray(rubbish_pos),
                              approach_distance=0.15,
                              approach_axis='z')
        _abort_if_done(done, 6)

        # STEP-7  place rubbish into the bin.  Keep the reward from the
        # move as well: if the episode ends there, the place is skipped
        # and the reported reward must come from the move, not a placeholder.
        print('\n[STEP 7] move → bin')
        _, reward, done = move(env, task, target_pos=np.asarray(bin_pos))
        if not done:
            print('[STEP 7] place → bin')
            _, reward, done = place(env,
                                    task,
                                    target_pos=np.asarray(bin_pos),
                                    approach_distance=0.15,
                                    approach_axis='z')

        # -----------------------------------------------------------
        #  Final status
        # -----------------------------------------------------------
        if done:
            print(f'\n>>>   TASK SUCCESS  – reward: {reward}')
        else:
            # Some RLBench tasks do not set done=True on success; log anyway.
            print('\n>>>   Task finished (done=False).  Success may still be registered.')

    except Exception as exc:
        print('!!!  Exception during task execution:', exc)
        raise
    finally:
        shutdown_environment(env)
        print('\n=======================  END TASK  =======================\n')


if __name__ == '__main__':
    # Script entry point: run the full oracle plan once.
    run_task()