import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment
from skill_code import *
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions

def _explore_missing_predicate(env, task, positions):
    """Pick at each known position and report which knowledge predicate shows up.

    For every entry in ``positions`` a ``pick`` skill is attempted.
    - If the pick raises, the failure is logged and the next position is tried.
    - Otherwise the current observation is inspected for the flags
      ``weight_known``, ``durability_known`` and ``temperature_known``, in that
      order.

    NOTE(review): nothing is released between attempts. Once one pick
    succeeds, later picks may fail because the gripper is already holding
    something. Confirm this is intended.

    Args:
        env: Environment handle returned by ``setup_environment()``.
        task: Task handle returned by ``setup_environment()``.
        positions: Mapping of names to target positions, as returned by
            ``get_object_positions()``.

    Returns:
        A ``(missing_predicate, done)`` tuple.
        - ``missing_predicate`` is ``'weight-known'``, ``'durability-known'``,
          ``'temperature-known'`` or ``None`` if no flag was found.
        - ``done`` is True if a pick reported that the episode ended.
    """
    # (observation attribute, predicate name), checked in priority order.
    predicate_flags = (
        ('weight_known', 'weight-known'),
        ('durability_known', 'durability-known'),
        ('temperature_known', 'temperature-known'),
    )
    for name, pos in positions.items():
        try:
            _, _, done = pick(
                env,
                task,
                target_pos=pos,
                approach_distance=0.15,
                max_steps=100,
                threshold=0.01,
                approach_axis='z',
                timeout=5.0,
            )
        except Exception as exc:
            # Best-effort exploration: an unreachable object is not fatal,
            # but report it rather than silently swallowing the error.
            print(f"[Exploration] Could not pick '{name}': {exc!r}")
            continue
        if done:
            # The episode is over; the caller must not keep issuing actions.
            return None, True
        state = task.get_observation()
        for attr, predicate in predicate_flags:
            # A missing attribute counts as "flag not set".
            if getattr(state, attr, False):
                return predicate, False
    return None, False


def _ensure_gripper_empty(env, task, positions):
    """Place whatever the gripper holds at a safe drop location.

    The drop target is ``positions['safe_drop']`` if present. Otherwise it is
    the first position in ``positions``. An observation without a
    ``handempty`` attribute is treated as empty-handed.

    Args:
        env: Environment handle returned by ``setup_environment()``.
        task: Task handle returned by ``setup_environment()``.
        positions: Mapping of names to target positions.

    Returns:
        True if the episode ended during the place, otherwise False.
    """
    state = task.get_observation()
    if getattr(state, 'handempty', True):
        return False

    if 'safe_drop' in positions:
        safe_drop = positions['safe_drop']
    else:
        # Fallback chosen lazily; None when there are no positions at all.
        safe_drop = next(iter(positions.values()), None)
    if safe_drop is None:
        print("[Task] Gripper not empty, but no position is available to place at.")
        return False

    print("[Task] Gripper not empty, placing at safe drop:", safe_drop)
    _, _, done = place(env, task, target_pos=safe_drop)
    if done:
        print("[Task] Episode ended unexpectedly after place.")
        return True
    return False


def _open_drawer(env, task, positions):
    """Execute the oracle drawer-opening plan.

    The steps are: rotate, move to the drawer side, move to the anchor, grasp
    the handle, and pull.

    Args:
        env: Environment handle returned by ``setup_environment()``.
        task: Task handle returned by ``setup_environment()``.
        positions: Mapping that must contain ``'drawer_side'`` and
            ``'drawer_anchor'``.

    Returns:
        True if the episode ended before the plan finished, otherwise False.

    Raises:
        KeyError: If ``'drawer_side'`` or ``'drawer_anchor'`` is missing.
    """
    # 1) Rotate gripper from 'zero_deg' to 'ninety_deg'
    print("[Plan] Rotating gripper to ninety_deg")
    _, _, done = rotate(env, task,
                        from_angle='zero_deg',
                        to_angle='ninety_deg')
    if done:
        print("[Plan] Episode ended during rotate.")
        return True

    # 2) Move gripper to the side position of the drawer
    side_pos = positions['drawer_side']
    print("[Plan] Moving gripper to drawer side position:", side_pos)
    # Read the pose now: an earlier place/rotate may have moved the gripper,
    # so any observation taken before those steps would be stale.
    gripper_pos = task.get_observation().gripper_pose[:3]
    _, _, done = move(env, task,
                      from_pos=gripper_pos,
                      to_pos=side_pos)
    if done:
        print("[Plan] Episode ended during move-to-side.")
        return True

    # 3) Move gripper from side to anchor position
    anchor_pos = positions['drawer_anchor']
    print("[Plan] Moving gripper to drawer anchor position:", anchor_pos)
    _, _, done = move(env, task,
                      from_pos=side_pos,
                      to_pos=anchor_pos)
    if done:
        print("[Plan] Episode ended during move-to-anchor.")
        return True

    # 4) Pick the drawer handle at the anchor position (tighter tolerances
    #    than exploration picks).
    print("[Plan] Picking drawer handle")
    _, _, done = pick(
        env,
        task,
        target_pos=anchor_pos,
        approach_distance=0.05,
        max_steps=50,
        threshold=0.005,
        approach_axis='z',
        timeout=5.0,
    )
    if done:
        print("[Plan] Episode ended during pick-drawer.")
        return True

    # 5) Pull the drawer open
    print("[Plan] Pulling drawer open")
    _, _, done = pull(env, task)
    if done:
        print("[Plan] Episode ended during pull.")
        return True
    return False


def run_skeleton_task():
    """Set up the environment, explore, then run the drawer-opening plan.

    The environment is always shut down via ``finally``. The closing banner
    is printed only when every phase completes without the episode ending.
    """
    print("===== Starting Skeleton Task =====")

    # === Environment Setup ===
    env, task = setup_environment()
    try:
        # Reset to initial state
        _, obs = task.reset()

        # Initialize video capture if needed
        init_video_writers(obs)

        # Wrap step and get_observation for recording
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # === Retrieve Object Positions ===
        # Example: positions might contain keys like
        # 'object_1', 'drawer_side', 'drawer_anchor', 'safe_drop', etc.
        positions = get_object_positions()

        # === Exploration Phase: Identify Which Knowledge Predicate Is Missing ===
        missing_predicate, done = _explore_missing_predicate(env, task, positions)
        if done:
            print("[Exploration] Episode ended during exploration.")
            return
        print(f"[Exploration] Identified missing predicate: {missing_predicate}")

        # === Ensure Gripper Is Empty Before Continuing ===
        if _ensure_gripper_empty(env, task, positions):
            return

        # === Oracle Plan Execution: Open a Drawer ===
        if _open_drawer(env, task, positions):
            return

    finally:
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Run only when executed as a script, not when imported.
    run_skeleton_task()