import numpy as np
from scipy.spatial.transform import Rotation as R

from env import setup_environment, shutdown_environment
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions

# -------- predefined low-level skills --------
from skill_code import rotate, move, pick, pull, place


def _quat_from_euler(rx_deg: float, ry_deg: float, rz_deg: float):
    """Utility: convert Euler angles (deg) → quaternion (xyzw order)."""
    return R.from_euler(
        'xyz',
        [np.deg2rad(rx_deg), np.deg2rad(ry_deg), np.deg2rad(rz_deg)]
    ).as_quat()


def run_combined_task():
    """Run the scripted 'open bottom drawer, then move tomatoes to plate' plan.

    Sequence, as coded below:
      1. Set up the environment, reset the task and wrap ``task.step`` /
         ``task.get_observation`` with the video-recording helpers.
      2. Look up the drawer, plate and tomato/item entries returned by
         ``get_object_positions()``.
      3. Rotate, move to the drawer side and anchor, pick at the anchor,
         then pull.
      4. For every tomato/item (in sorted key order): move above it, pick
         it, and place it at the plate position.

    Execution stops as soon as one of the checked ``done`` flags is set.
    The environment is always shut down on exit, including on errors.

    Raises:
        KeyError: if the drawer, plate, or tomato/item entries are missing
            from the object-position lookup.
    """
    print('===== Starting Combined Task =====')

    # Created outside the try-block: if setup itself fails there is no
    # environment to shut down.
    env, task = setup_environment()
    try:
        _, obs = task.reset()

        # Recording hooks – wrap step / get_observation with the video helpers.
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # ------------------------------------------------------------------
        # Resolve every pose the plan needs before moving anything
        # ------------------------------------------------------------------
        positions = get_object_positions()

        if not ('bottom_side_pos' in positions
                and 'bottom_anchor_pos' in positions):
            raise KeyError('Drawer pose keys missing in object_positions.')
        if 'plate' not in positions:
            raise KeyError('"plate" pose missing in object_positions.')

        drawer_side = np.asarray(positions['bottom_side_pos'])
        drawer_anchor = np.asarray(positions['bottom_anchor_pos'])
        plate_pos = np.asarray(positions['plate'])

        # Either naming scheme is accepted; sorting keeps the order stable.
        fruit_names = sorted(
            name for name in positions
            if name.startswith(('item', 'tomato'))
        )
        if not fruit_names:
            raise KeyError('No tomato / item positions found.')
        fruit_targets = [np.asarray(positions[name]) for name in fruit_names]

        # ------------------------------------------------------------------
        # Drawer phase
        # ------------------------------------------------------------------
        target_quat = _quat_from_euler(0, 0, 90)   # 90 degrees about z

        # [Frozen Code Start]
        obs, reward, done = rotate(env, task, target_quat)
        obs, reward, done = move(env, task, drawer_side)
        obs, reward, done = move(env, task, drawer_anchor)
        # [Frozen Code End]

        if done:
            print('[Task] Terminated during initial 3 steps.')
            return

        # Same approach settings are used for every pick / place call.
        approach_kwargs = {'approach_distance': 0.10, 'approach_axis': '-z'}

        # Grasp at the drawer anchor.
        obs, reward, done = pick(env, task, drawer_anchor, **approach_kwargs)
        if done:
            print('[Task] Terminated while picking drawer handle.')
            return

        # Pull by 0.25 along the 'x' axis (described upstream as ~25 cm).
        obs, reward, done = pull(env, task, pull_distance=0.25, pull_axis='x')
        if done:
            print('[Task] Terminated during drawer pull.')
            return

        # ------------------------------------------------------------------
        # Transfer phase – one move / pick / place cycle per tomato
        # ------------------------------------------------------------------
        hover_offset = np.array([0, 0, 0.15])    # added to each target's z
        for number, target in enumerate(fruit_targets, start=1):
            # hover above the current tomato
            obs, reward, done = move(env, task, target + hover_offset)
            if done:
                print(f'[Task] Terminated while hovering above tomato #{number}.')
                return

            # pick it up
            obs, reward, done = pick(env, task, target, **approach_kwargs)
            if done:
                print(f'[Task] Terminated while picking tomato #{number}.')
                return

            # put it down at the plate position
            obs, reward, done = place(env, task, plate_pos, **approach_kwargs)
            if done:
                # environment may signal success early; report and stop
                print('[Task] Environment reported completion.')
                print('[Task] Completed successfully! Reward:', reward)
                break
        else:
            # every tomato was handled without any step reporting done
            print('[Task] Oracle plan finished; environment still active.')

    finally:
        shutdown_environment(env)
        print('===== End of Combined Task =====')


# Script entry point: run the combined task only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    run_combined_task()