import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

from skill_code import *
from video import init_video_writers, recording_step, recording_get_observation

from object_positions import get_object_positions

def _execute_step(label, skill, *args, **kwargs):
    """Run one plan step and report whether the plan may continue.

    ``skill`` is called as ``skill(*args, **kwargs)`` and its result is unpacked
    as an ``(obs, reward, done)`` triple, matching how every skill call in this
    script is used.

    Args:
        label: Short step name used in the "[Plan] ..." log messages.
        skill: Skill function to invoke (e.g. ``rotate``, ``move``, ``pick``).
        *args, **kwargs: Forwarded unchanged to ``skill``.

    Returns:
        True if the call succeeded and did not report ``done``; False if it
        raised or reported ``done``. Both failure cases are logged.
    """
    try:
        _obs, _reward, done = skill(*args, **kwargs)
    except Exception as e:  # plan boundary: log and abort the remaining plan
        print(f"[Plan] Error during {label}:", e)
        return False
    if done:
        print(f"[Plan] Early termination during {label}.")
        return False
    return True


def _find_position_key(positions, keywords):
    """Return the first key of ``positions`` containing any of ``keywords``.

    Matching is a case-insensitive substring test on the key; returns None when
    no key matches.
    """
    return next(
        (key for key in positions if any(word in key.lower() for word in keywords)),
        None,
    )


def _explore_rotate(env, task):
    """Probe the ``rotate`` skill once with an identity quaternion.

    An exception from ``rotate`` is logged and tolerated, because this probe is
    best-effort. If the probe reports ``done``, the episode has ended, so the
    caller is told not to run the plan.

    Returns:
        True if the plan should proceed, False if the episode terminated.
    """
    print("[Exploration] Testing rotate skill for missing 'rotated' predicate...")
    identity_quat = np.array([0.0, 0.0, 0.0, 1.0])  # (x, y, z, w) identity
    proceed = True
    try:
        _obs, _rew, done = rotate(env, task, identity_quat, max_steps=1, threshold=0.1, timeout=1.0)
        print("[Exploration] rotate call completed without error.")
        if done:
            # Continuing would act on an already-finished episode.
            print("[Exploration] Episode terminated during exploration; skipping plan.")
            proceed = False
    except Exception as e:
        print("[Exploration] rotate call raised an error, likely missing predicate 'rotated':", e)
    print("[Exploration] End of exploration phase.\n")
    return proceed


def _open_drawer(env, task, positions):
    """Execute the oracle drawer-opening plan, stopping at the first failed step.

    Steps: rotate 90° about Z -> move to the "side" position -> move to the
    "anchor"/"handle" position -> pick the handle -> pull.

    Args:
        env, task: Objects returned by ``setup_environment()``.
        positions: Mapping from position name to target position, as returned
            by ``get_object_positions()``.
    """
    # 1) Rotate the gripper to 90 degrees about the Z axis
    print("[Plan] Rotating gripper to 90 degrees about Z axis")
    # quaternion for a 90° rotation around Z: (x,y,z,w) = (0,0, sin(π/4), cos(π/4))
    quat_90z = np.array([0.0, 0.0, np.sin(np.pi / 4), np.cos(np.pi / 4)])
    if not _execute_step(
        "rotate", rotate, env, task, quat_90z,
        max_steps=100, threshold=0.02, timeout=10.0,
    ):
        return

    # 2) Move the gripper to the "side" position of the drawer
    side_key = _find_position_key(positions, ("side",))
    if side_key is None:
        print("[Plan] Could not find a 'side' position in object_positions.")
        return
    side_pos = positions[side_key]
    print(f"[Plan] Moving to side position '{side_key}':", side_pos)
    if not _execute_step(
        "move-to-side", move, env, task,
        target_pos=side_pos,
        approach_distance=0.10,
        max_steps=100,
        threshold=0.01,
        approach_axis='z',
        timeout=10.0,
    ):
        return

    # 3) Move the gripper from side to the anchor (handle) position
    anchor_key = _find_position_key(positions, ("anchor", "handle"))
    if anchor_key is None:
        print("[Plan] Could not find an 'anchor' or 'handle' position in object_positions.")
        return
    anchor_pos = positions[anchor_key]
    print(f"[Plan] Moving to anchor position '{anchor_key}':", anchor_pos)
    if not _execute_step(
        "move-to-anchor", move, env, task,
        target_pos=anchor_pos,
        approach_distance=0.05,
        max_steps=100,
        threshold=0.008,
        approach_axis='z',
        timeout=10.0,
    ):
        return

    # 4) Grasp the drawer handle (pick action)
    print("[Plan] Grasping the drawer handle")
    if not _execute_step(
        "pick-drawer", pick, env, task,
        target_pos=anchor_pos,
        approach_distance=0.01,
        max_steps=80,
        threshold=0.005,
        approach_axis='z',
        timeout=8.0,
    ):
        return

    # 5) Pull the drawer open
    print("[Plan] Pulling the drawer open")
    if not _execute_step("pull", pull, env, task, max_steps=60, timeout=5.0):
        return

    print("[Plan] Drawer opening sequence complete.")


def run_skeleton_task():
    '''Generic skeleton for running any task in your simulation.

    Sets up the environment, resets the task, and wraps ``task.step`` and
    ``task.get_observation`` for video recording. It then runs a one-shot
    exploration probe of ``rotate`` followed by the drawer-opening plan. The
    environment is always shut down. The end banner is printed whenever the run
    returns normally, including when the plan is aborted early.
    '''
    print("===== Starting Skeleton Task =====")

    # === Environment Setup ===
    env, task = setup_environment()
    try:
        # Reset the task to its initial state
        _descriptions, obs = task.reset()

        # (Optional) Initialize video writers for capturing your simulation
        init_video_writers(obs)

        # Wrap the task steps for recording
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # === Retrieve Object Positions ===
        positions = get_object_positions()
        print("[Task] Retrieved object positions:", positions.keys())

        # === Exploration Phase, then Oracle Plan for Opening a Drawer ===
        if _explore_rotate(env, task):
            _open_drawer(env, task, positions)
    finally:
        # Always ensure the environment is properly shutdown
        shutdown_environment(env)

    # Reached on success and on an early plan abort alike (the helpers return,
    # not raise), so the banner is no longer skipped by early exits.
    print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Run the skeleton task only when executed as a script, not when imported.
    run_skeleton_task()