import numpy as np
from scipy.spatial.transform import Rotation as R

from env import setup_environment, shutdown_environment
from skill_code import rotate, move, pick, pull, place
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


def quat_from_euler(roll: float, pitch: float, yaw: float, seq: str = 'xyz'):
    """Build an xyzw (scalar-last) quaternion from Euler angles in degrees.

    Args:
        roll, pitch, yaw: the three angles, in degrees, matched in order to the
            axes named in ``seq``.
        seq: axis sequence forwarded to SciPy's ``Rotation.from_euler``
            (lowercase = extrinsic, uppercase = intrinsic).

    Returns:
        The quaternion as returned by SciPy's ``Rotation.as_quat`` (x, y, z, w).
    """
    angles_deg = (roll, pitch, yaw)
    rotation = R.from_euler(seq, angles_deg, degrees=True)
    return rotation.as_quat()


def compute_pull_params(anchor_pos: np.ndarray, joint_pos: np.ndarray,
                        default_distance: float = 0.10,
                        min_offset: float = 1e-3):
    """
    Decide the pull direction and distance for opening a drawer.

    The drawer's joint position is taken to be at the back of the drawer and the
    drawer-handle anchor at the front.  Opening the drawer therefore moves the
    handle *away* from the joint, so the pull direction is
    ``anchor_pos - joint_pos``.  Only the dominant (largest-magnitude) axis of
    that offset is used, and its magnitude is the pull distance.

    Args:
        anchor_pos: 3-D position of the drawer handle (any array-like of 3 numbers).
        joint_pos: 3-D position of the drawer joint (any array-like of 3 numbers).
        default_distance: distance returned when the offset along the dominant
            axis is smaller than ``min_offset``.
        min_offset: offsets below this magnitude are treated as degenerate.

    Returns:
        ``(pull_axis, pull_distance)`` where ``pull_axis`` is one of
        ``'x'``, ``'y'``, ``'z'``, ``'-x'``, ``'-y'``, ``'-z'`` and
        ``pull_distance`` is a non-negative Python float.

    Raises:
        ValueError: if either position does not contain exactly 3 values.
    """
    anchor = np.asarray(anchor_pos, dtype=float).reshape(-1)
    joint = np.asarray(joint_pos, dtype=float).reshape(-1)
    if anchor.shape != (3,) or joint.shape != (3,):
        raise ValueError(
            f"expected 3-D positions, got shapes {anchor.shape} and {joint.shape}"
        )

    # Opening direction: from the back of the drawer (joint) toward its
    # front (handle anchor).  Using joint - anchor would push the drawer shut.
    diff = anchor - joint
    # Choose the axis with the largest absolute offset.
    axis_index = int(np.argmax(np.abs(diff)))
    axis = 'xyz'[axis_index]
    sign = '' if diff[axis_index] >= 0 else '-'
    pull_distance = float(abs(diff[axis_index]))
    # If the environment reports an extremely small offset, fall back to a
    # sensible default distance.
    if pull_distance < min_offset:
        pull_distance = float(default_distance)
    return f'{sign}{axis}', pull_distance


def run_oracle_plan():
    """
    Run the fixed oracle plan with the predefined skills only.

    The plan opens the bottom drawer (rotate, approach, grasp the handle, pull)
    and then moves tomato1 and tomato2 onto the plate.  Each skill call is
    unpacked as an ``(obs, reward, done)`` triple; if any step before the last
    reports ``done``, a message is printed and the plan returns early.  Any
    exception is printed and re-raised, and the environment is always shut down.
    """
    print("===== ORACLE PLAN: start =====")
    env, task = setup_environment()

    def stopped_after(banner, skill_call, exit_message):
        # Announce one plan step, execute it, and report whether it ended the episode.
        print(banner)
        _step_obs, _step_reward, finished = skill_call()
        if finished:
            print(exit_message)
        return finished

    try:
        # Reset the task and wrap step/get_observation so every call is recorded.
        _, initial_obs = task.reset()
        init_video_writers(initial_obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        positions = get_object_positions()

        # Bottom-drawer key poses.
        side_pos_bottom = np.array(positions['bottom_side_pos'])
        anchor_pos_bottom = np.array(positions['bottom_anchor_pos'])
        joint_pos_bottom = np.array(positions['bottom_joint_pos'])

        # Poses of the objects handled once the drawer is open.
        tomato1_pos = np.array(positions['tomato1'])
        tomato2_pos = np.array(positions['tomato2'])
        plate_pos = np.array(positions['plate'])

        # Target orientation for step 1: 90° about Z.
        target_quat = quat_from_euler(0, 0, 90)

        # Steps 1-4: orient the gripper, approach the handle, and grasp it.
        drawer_approach = (
            ("\n[Plan‑1] rotate gripper to 90° yaw",
             lambda: rotate(env, task, target_quat),
             "[Early‑Exit] task finished unexpectedly after rotate"),
            ("\n[Plan‑2] move gripper to drawer side‑pos‑bottom",
             lambda: move(env, task, side_pos_bottom),
             "[Early‑Exit] task finished unexpectedly after move‑to‑side"),
            ("\n[Plan‑3] move gripper to drawer anchor‑pos‑bottom",
             lambda: move(env, task, anchor_pos_bottom),
             "[Early‑Exit] task finished unexpectedly after move‑to‑anchor"),
            ("\n[Plan‑4] grasp (pick) the drawer handle",
             lambda: pick(env, task, anchor_pos_bottom, approach_axis='z'),
             "[Early‑Exit] task finished unexpectedly after pick‑drawer"),
        )
        for banner, skill_call, exit_message in drawer_approach:
            if stopped_after(banner, skill_call, exit_message):
                return

        # Step 5: pull the drawer open.  Parameters are derived only once the
        # handle is grasped, matching the original ordering.
        pull_axis, pull_distance = compute_pull_params(anchor_pos_bottom, joint_pos_bottom)
        if stopped_after(
            f"\n[Plan‑5] pull drawer: axis={pull_axis}, distance={pull_distance:.3f}",
            lambda: pull(env, task, pull_distance=pull_distance, pull_axis=pull_axis),
            "[Early‑Exit] task finished unexpectedly after pull",
        ):
            return

        # Steps 6-8: first tomato pick-and-place, then pick the second tomato.
        tomato_transfer = (
            ("\n[Plan‑6] pick tomato1",
             lambda: pick(env, task, tomato1_pos, approach_axis='z'),
             "[Early‑Exit] task finished unexpectedly after pick tomato1"),
            ("\n[Plan‑7] place tomato1 on plate",
             lambda: place(env, task, plate_pos, approach_axis='z'),
             "[Early‑Exit] task finished unexpectedly after place tomato1"),
            ("\n[Plan‑8] pick tomato2",
             lambda: pick(env, task, tomato2_pos, approach_axis='z'),
             "[Early‑Exit] task finished unexpectedly after pick tomato2"),
        )
        for banner, skill_call, exit_message in tomato_transfer:
            if stopped_after(banner, skill_call, exit_message):
                return

        # Step 9: the final placement; here done=True is the expected outcome.
        print("\n[Plan‑9] place tomato2 on plate")
        _final_obs, final_reward, finished = place(env, task, plate_pos, approach_axis='z')
        if not finished:
            print("[Warning] Goal execution ended with done=False")
        else:
            print("[Completed] Goal achieved with reward:", final_reward)

    except Exception as err:
        print("Exception during oracle plan execution:", str(err))
        raise
    finally:
        shutdown_environment(env)
        print("===== ORACLE PLAN: end =====")


if __name__ == '__main__':
    # Run the oracle plan only when this file is executed as a script,
    # not when it is imported as a module.
    run_oracle_plan()