import numpy as np
from scipy.spatial.transform import Rotation as R

from env import setup_environment, shutdown_environment
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions

# ----  Skill imports (provides the move, pick, place, rotate and pull skills)  ----
from skill_code import *


# ----------------------------- Helper Utils ----------------------------- #
def _symbolic_to_real_name(symbolic_name: str) -> str:
    """
    Convert symbolic PDDL-style names to the actual RLBench object names that
    appear in `object_positions`.  Only the names we need for this task are
    mapped; extend if necessary.
    """
    mapping = {
        # drawer related
        "side-pos-bottom":      "bottom_side_pos",
        "anchor-pos-bottom":    "bottom_anchor_pos",
        "side-pos-middle":      "middle_side_pos",
        "anchor-pos-middle":    "middle_anchor_pos",
        "side-pos-top":         "top_side_pos",
        "anchor-pos-top":       "top_anchor_pos",
        "nowhere-pos":          "waypoint1",

        # “real” objects keep their names
        "rubbish":              "rubbish",
        "bin":                  "bin",
    }
    return mapping.get(symbolic_name, symbolic_name)


def _vec_to_axis(v: np.ndarray) -> (str, float):
    """
    Given a 3‑D vector, return (‘axis-string’, distance) where axis‑string is in
    {x, -x, y, -y, z, -z}.  The chosen axis is the one with the largest
    absolute value component.  Distance is the magnitude along that axis (pos).
    """
    axis_idx = int(np.argmax(np.abs(v)))
    axis_char = ['x', 'y', 'z'][axis_idx]
    sign = '' if v[axis_idx] >= 0 else '-'
    axis_str = f'{sign}{axis_char}'
    distance = np.abs(v[axis_idx])
    return axis_str, distance


# ----------------------------- Main Routine ----------------------------- #
# Extra travel (m) added on top of the anchor->joint offset so the drawer opens fully.
_PULL_MARGIN = 0.05
# Height (m) the open gripper is lifted after letting go of the drawer handle.
_RETREAT_HEIGHT = 0.10


def _aborted(done, message: str) -> bool:
    """Return True, after printing *message*, when a skill reports the episode is done."""
    if done:
        print(message)
        return True
    return False


def run_combined_task():
    """Execute the oracle plan to open a drawer then throw rubbish in the bin.

    Plan: rotate the gripper 90 degrees about z, approach the bottom drawer
    handle via its side position, grasp it, pull the drawer open, release the
    handle and lift clear, pick up the rubbish, and place it in the bin.  If a
    skill reports ``done`` before the final placement, the plan stops early.
    The environment is always shut down on exit.
    """
    print("===========   START COMBINED TASK   ===========")

    env, task = setup_environment()
    try:
        # ---------------  Environment / Video  ---------------
        _descriptions, obs = task.reset()
        init_video_writers(obs)                       # optional video
        task.step = recording_step(task.step)         # wrap step for video
        task.get_observation = recording_get_observation(task.get_observation)

        # ---------------  Positions & Names  ---------------
        positions = get_object_positions()            # dict: name -> position

        def get_pos(sym):
            return np.array(positions[_symbolic_to_real_name(sym)])

        # ---- Drawer choice: use "bottom" drawer ----
        side_pos = get_pos("side-pos-bottom")         # side handle approach
        anchor_pos = get_pos("anchor-pos-bottom")     # handle grab point
        # The pull direction is (anchor - joint).  Without a joint position the
        # fallback makes that vector zero, which is reported below.
        joint_pos = np.array(positions.get("bottom_joint_pos", anchor_pos))

        rubbish_pos = get_pos("rubbish")
        bin_pos = get_pos("bin")

        # ---------------  PLAN EXECUTION  ---------------
        # Step 1: rotate gripper from zero_deg to ninety_deg (about z)
        target_quat = R.from_euler('z', 90, degrees=True).as_quat()
        obs, reward, done = rotate(env, task, target_quat)
        if _aborted(done, "[Abort] Episode finished during rotate."):
            return

        # Step 2: move gripper to the side position of the bottom drawer
        obs, reward, done = move(env, task, side_pos)
        if _aborted(done, "[Abort] Episode finished during move to side‑pos."):
            return

        # Step 3: move from side-pos to anchor-pos (front of drawer handle)
        obs, reward, done = move(env, task, anchor_pos)
        if _aborted(done, "[Abort] Episode finished during move to anchor‑pos."):
            return

        # Step 4: pick drawer handle (close gripper)
        obs, reward, done = pick(
            env, task, target_pos=anchor_pos,
            approach_distance=0.08, approach_axis='y'
        )
        if _aborted(done, "[Abort] Episode finished during handle pick."):
            return

        # Step 5: pull drawer open along the dominant anchor->joint axis
        pull_vec = anchor_pos - joint_pos
        if np.linalg.norm(pull_vec) < 1e-9:
            print("[Warning] Cannot infer drawer pull direction "
                  "(joint position missing or equal to anchor); defaulting to +x.")
        pull_axis, pull_distance = _vec_to_axis(pull_vec)
        pull_distance += _PULL_MARGIN                 # additive margin, not a scale
        obs, reward, done = pull(
            env, task, pull_distance=pull_distance, pull_axis=pull_axis
        )
        if _aborted(done, "[Abort] Episode finished during drawer pull."):
            return

        # Release the handle where the gripper currently is BEFORE moving;
        # moving first would drag the freshly opened drawer with the gripper.
        # place() with zero approach distance is used only to open the gripper.
        release_pos = np.array(obs.gripper_pose[:3])
        obs, reward, done = place(
            env, task, target_pos=release_pos,
            approach_distance=0.0, approach_axis='z'
        )
        if _aborted(done, "[Abort] Episode finished when releasing handle."):
            return

        # Lift the open gripper clear of the handle.
        obs, reward, done = move(
            env, task, release_pos + np.array([0.0, 0.0, _RETREAT_HEIGHT])
        )
        if _aborted(done, "[Abort] Episode finished during retreat from drawer."):
            return

        # Step 6: pick the rubbish from the table (hover first)
        obs, reward, done = move(env, task, rubbish_pos + np.array([0, 0, 0.20]))
        if _aborted(done, "[Abort] Episode finished during hover above rubbish."):
            return
        obs, reward, done = pick(
            env, task, target_pos=rubbish_pos,
            approach_distance=0.15, approach_axis='z'
        )
        if _aborted(done, "[Abort] Episode finished during rubbish pick."):
            return

        # Step 7: place rubbish in the bin
        obs, reward, done = place(
            env, task, target_pos=bin_pos,
            approach_distance=0.15, approach_axis='z'
        )
        if done:
            print("[Success] Task completed after placing rubbish!")
        else:
            print("[Info] Task finished plan but environment not signalling done yet.")
    finally:
        shutdown_environment(env)
        print("===========   END COMBINED TASK   ===========")


# Run the combined drawer/rubbish task only when executed as a script, not on import.
if __name__ == "__main__":
    run_combined_task()
