import numpy as np
from scipy.spatial.transform import Rotation as R
from pyrep.objects.shape import Shape

from env import setup_environment, shutdown_environment

# Import the predefined skills exactly as provided
from skill_code import pick, place, move, rotate, pull

from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


# ---------------------------------------------------------------------
# Helper utilities (ONLY small local helpers – NOT new robot skills)
# ---------------------------------------------------------------------
def _concat_dict(*dicts):
    """Merge dictionaries (later ones override earlier ones)."""
    merged = {}
    for d in dicts:
        if d:
            merged.update(d)
    return merged


def _vector_to_axis(v):
    """
    Convert a 3-D vector into one of the six axis labels that the pull
    primitive understands:  'x', '-x', 'y', '-y', 'z', '-z'.
    """
    axis_names = ['x', 'y', 'z']
    idx = int(np.argmax(np.abs(v)))
    sign = 1 if v[idx] >= 0 else -1
    return axis_names[idx] if sign > 0 else f'-{axis_names[idx]}'


def _acc_reward(total, r):
    """Safely accumulate rewards that might be None."""
    if r is None:
        return total
    try:
        return total + float(r)
    except Exception:
        return total


# ---------------------------------------------------------------------
# Oracle-plan execution
# ---------------------------------------------------------------------
def _resolve_positions():
    """Collect the object positions the oracle plan needs.

    Starts from ``get_object_positions()`` and, for each expected key that is
    absent, tries ``Shape(key).get_position()`` directly through PyRep.
    Lookups that fail are skipped (best effort) but reported in one summary
    line instead of being silently swallowed.

    Returns:
        dict: object key -> position, merged from both sources (values from
        ``get_object_positions()`` are kept; PyRep only fills gaps).

    Raises:
        RuntimeError: if any object required by the plan is still missing.
    """
    # Treat a None/empty result as "nothing known yet" so the membership
    # tests below cannot raise TypeError.
    positions = get_object_positions() or {}

    fallback_keys = [
        'bottom_anchor_pos', 'bottom_side_pos', 'bottom_joint_pos',
        'middle_anchor_pos', 'middle_side_pos', 'middle_joint_pos',
        'top_anchor_pos',    'top_side_pos',    'top_joint_pos',
        'rubbish', 'bin',
    ]
    extra = {}
    unresolved = []
    for key in fallback_keys:
        if key in positions:
            continue
        try:
            extra[key] = Shape(key).get_position()
        except Exception as exc:
            # Best effort: record the failure and keep going; only the
            # required keys below are fatal.
            unresolved.append(f"{key} ({type(exc).__name__})")
    positions = _concat_dict(positions, extra)
    if unresolved:
        print("[Init] PyRep fallback could not resolve: " + ", ".join(unresolved))

    # Only the bottom drawer, the rubbish and the bin are used by the plan.
    must_have = ['bottom_anchor_pos', 'bottom_side_pos',
                 'bottom_joint_pos', 'rubbish', 'bin']
    missing = [k for k in must_have if k not in positions]
    if missing:
        raise RuntimeError(
            f"[Init] Missing position(s) for required object(s) {missing}. "
            "Check object_positions.py or scene naming."
        )
    return positions


def run_skeleton_task():
    """Execute the oracle plan: open the bottom drawer, then bin the rubbish.

    Steps: rotate the gripper 90 degrees about its own z axis, move to the
    bottom drawer's side position and then its anchor position, grasp the
    handle, pull the drawer open, pick the rubbish and place it in the bin.
    The plan stops as soon as any skill reports ``done``.

    The environment is shut down and the closing banner is printed on every
    exit path (full completion, early ``done``, or an exception).

    Raises:
        RuntimeError: if a required object position cannot be resolved.
    """
    print("===== Starting Skeleton Task =====")

    # === Environment setup ===
    env, task = setup_environment()
    try:
        # Reset the RLBench task
        _descriptions, obs = task.reset()

        # ---- Video hooks (optional) ----
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # ---- Obtain object positions (raises if a critical one is missing) ----
        positions = _resolve_positions()

        # -----------------------------------------------------------------
        # Execute the oracle plan, step-by-step
        # -----------------------------------------------------------------
        reward_sum = 0.0

        # STEP-1 ─ rotate(gripper, zero_deg → ninety_deg)
        print("\n[Plan] Step-1 : rotate gripper 0° → 90°")
        obs = task.get_observation()
        current_quat = obs.gripper_pose[3:7]            # xyzw
        add_quat = R.from_euler('z', 90, degrees=True).as_quat()
        # scipy composes p * q as "apply q, then p", so right-multiplying
        # turns the gripper 90 degrees about its own (local) z axis.
        target_quat = (R.from_quat(current_quat) * R.from_quat(add_quat)).as_quat()

        # -----------------------------------------------------------------
        # DO NOT MODIFY THE FOLLOWING LINE (Frozen region)
        # -----------------------------------------------------------------
        obs, reward, done = rotate(env, task, target_quat)
        # -----------------------------------------------------------------

        reward_sum = _acc_reward(reward_sum, reward)
        if done:
            print(f"[Plan] Task finished prematurely after rotate "
                  f"(accumulated reward: {reward_sum:.3f})")
            return

        # STEP-2 ─ move-to-side (nowhere → bottom_side_pos)
        print("\n[Plan] Step-2 : move to bottom side-pos")
        side_pos = np.array(positions['bottom_side_pos'])
        obs, reward, done = move(env, task, side_pos)
        reward_sum = _acc_reward(reward_sum, reward)
        if done:
            print(f"[Plan] Task finished prematurely after move-to-side "
                  f"(accumulated reward: {reward_sum:.3f})")
            return

        # STEP-3 ─ move-to-anchor (side → bottom_anchor_pos)
        print("\n[Plan] Step-3 : move to bottom anchor-pos")
        anchor_pos = np.array(positions['bottom_anchor_pos'])
        obs, reward, done = move(env, task, anchor_pos)
        reward_sum = _acc_reward(reward_sum, reward)
        if done:
            print(f"[Plan] Task finished prematurely after move-to-anchor "
                  f"(accumulated reward: {reward_sum:.3f})")
            return

        # STEP-4 ─ pick-drawer (grasp drawer handle)
        print("\n[Plan] Step-4 : grasp drawer handle (pick-drawer)")
        # Re-use generic pick primitive; smaller approach distance & axis toward handle
        obs, reward, done = pick(
            env, task,
            target_pos=anchor_pos,
            approach_distance=0.06,
            max_steps=100,
            threshold=0.01,
            approach_axis='-y',     # typical handle approach
            timeout=10.0
        )
        reward_sum = _acc_reward(reward_sum, reward)
        if done:
            print(f"[Plan] Task finished prematurely after pick-drawer "
                  f"(accumulated reward: {reward_sum:.3f})")
            return

        # STEP-5 ─ pull (open drawer)
        print("\n[Plan] Step-5 : pull drawer open")
        joint_pos = np.array(positions['bottom_joint_pos'])
        # Pull along the dominant axis of joint → anchor.
        # NOTE(review): assumes the handle (anchor) lies on the pull-out side
        # of the drawer joint — confirm against the scene layout.
        drawer_vec = anchor_pos - joint_pos
        pull_axis = _vector_to_axis(drawer_vec)
        obs, reward, done = pull(
            env, task,
            pull_distance=0.20,      # 20 cm pull
            pull_axis=pull_axis,
            max_steps=100,
            threshold=0.01,
            timeout=10.0
        )
        reward_sum = _acc_reward(reward_sum, reward)
        if done:
            print(f"[Plan] Task finished prematurely after pull "
                  f"(accumulated reward: {reward_sum:.3f})")
            return

        # STEP-6 ─ pick(rubbish, table)
        print("\n[Plan] Step-6 : pick rubbish on table")
        rubbish_pos = np.array(positions['rubbish'])
        obs, reward, done = pick(
            env, task,
            target_pos=rubbish_pos,
            approach_distance=0.12,
            max_steps=100,
            threshold=0.01,
            approach_axis='z',
            timeout=10.0
        )
        reward_sum = _acc_reward(reward_sum, reward)
        if done:
            print(f"[Plan] Task finished prematurely after pick rubbish "
                  f"(accumulated reward: {reward_sum:.3f})")
            return

        # STEP-7 ─ place(rubbish, bin)
        print("\n[Plan] Step-7 : place rubbish into bin")
        bin_pos = np.array(positions['bin'])
        obs, reward, done = place(
            env, task,
            target_pos=bin_pos,
            approach_distance=0.12,
            max_steps=100,
            threshold=0.01,
            approach_axis='z',
            timeout=10.0
        )
        reward_sum = _acc_reward(reward_sum, reward)

        # -----------------------------------------------------------------
        # Final report
        # -----------------------------------------------------------------
        if done:
            print(f"[Plan] Task completed! (done=True)  Total reward: {reward_sum:.3f}")
        else:
            print(f"[Plan] Plan finished but environment did not signal done "
                  f"(done=False).  Accumulated reward: {reward_sum:.3f}")

    finally:
        shutdown_environment(env)
        # Printed here so early `return`s (done=True mid-plan) also close the log.
        print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Run the oracle plan only when executed as a script, not when imported.
    run_skeleton_task()