import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

from skill_code import move, pick, place, rotate, pull

from video import init_video_writers, recording_step, recording_get_observation

from object_positions import get_object_positions

# Motion parameters shared by the per-object move/pick/place probes.
_OBJECT_SKILL_KWARGS = {
    'approach_distance': 0.15,
    'max_steps': 100,
    'threshold': 0.01,
    'approach_axis': 'z',
    'timeout': 10.0,
}

# Observation attribute -> predicate name, checked after picking an object.
_PICK_FLAGS = {
    'weight_known': 'weight-known',
    'durability_known': 'durability-known',
}

# Observation attribute -> predicate name, checked after pulling the drawer.
_PULL_FLAGS = {
    'lock_known': 'lock-known',
}

# Every predicate the summary compares against. NOTE(review): nothing in this
# script ever adds 'identified' or 'temperature-known', so those two are always
# reported as missing — confirm whether that is intended.
_EXPECTED_PREDICATES = {
    'identified', 'temperature-known', 'weight-known',
    'durability-known', 'lock-known',
}


def _run_skill(skill, env, task, stage, **skill_kwargs):
    '''Call ``skill(env, task, **skill_kwargs)`` and report early termination.

    Returns ``(obs, done)``. When ``done`` is true a message naming ``stage``
    is printed. The reward is discarded because nothing here uses it.
    '''
    obs, _reward, done = skill(env, task, **skill_kwargs)
    if done:
        print(f"[Explore] Episode ended prematurely during {stage}")
    return obs, done


def _record_flags(obs, discovered, flag_to_predicate):
    '''Add to ``discovered`` each predicate whose attribute on ``obs`` is truthy.'''
    # getattr with a default is equivalent to ``hasattr(obs, a) and obs.a``.
    for attr, predicate in flag_to_predicate.items():
        if getattr(obs, attr, False):
            discovered.add(predicate)


def _explore_objects(env, task, positions, discovered):
    '''Move to, pick up, and place back every object in ``positions``.

    Predicates sensed while holding an object are added to ``discovered``.
    Returns False as soon as any skill reports that the episode ended.
    '''
    for obj_name, obj_pos in positions.items():
        print(f"[Explore] Moving to {obj_name} at {obj_pos}")
        _obs, done = _run_skill(
            move, env, task, 'move',
            target_pos=obj_pos, **_OBJECT_SKILL_KWARGS,
        )
        if done:
            return False

        print(f"[Explore] Picking {obj_name} to sense weight/durability")
        obs, done = _run_skill(
            pick, env, task, 'pick',
            target_pos=obj_pos, **_OBJECT_SKILL_KWARGS,
        )
        if done:
            return False
        # Inspect the post-pick observation for side-effect predicates.
        _record_flags(obs, discovered, _PICK_FLAGS)

        print(f"[Explore] Placing {obj_name} back")
        _obs, done = _run_skill(
            place, env, task, 'place',
            target_pos=obj_pos, **_OBJECT_SKILL_KWARGS,
        )
        if done:
            return False
    return True


def _find_drawer_handle(positions):
    '''Return the first name in ``positions`` containing 'drawer', else None.'''
    return next((name for name in positions if 'drawer' in name), None)


def _explore_drawer(env, task, positions, discovered):
    '''Approach and pull the first drawer-like object in ``positions``, if any.

    Returns False only when the episode ends during the approach. If it ends
    during the pull, that is reported but True is still returned. The pull is
    the last environment interaction, so the caller can still print the summary.
    '''
    drawer_handle = _find_drawer_handle(positions)
    if drawer_handle is None:
        return True

    dh_pos = positions[drawer_handle]
    print(f"[Explore] Approaching drawer handle '{drawer_handle}' at {dh_pos}")
    _obs, done = _run_skill(
        move, env, task, 'drawer approach',
        target_pos=dh_pos,
        **{**_OBJECT_SKILL_KWARGS, 'approach_distance': 0.10},
    )
    if done:
        return False

    print(f"[Explore] Pulling drawer handle '{drawer_handle}'")
    # Previously the ``done`` result of pull was ignored; _run_skill now
    # reports it like every other skill call.
    obs, _done = _run_skill(
        pull, env, task, 'pull',
        target_pos=dh_pos,
        approach_distance=0.00,
        max_steps=1,
        threshold=0.00,
        timeout=5.0,
    )
    _record_flags(obs, discovered, _PULL_FLAGS)
    return True


def _report_summary(discovered):
    '''Print the discovered predicates and those still missing.'''
    print("=== Exploration Summary ===")
    print("Discovered predicates:", discovered)
    missing = _EXPECTED_PREDICATES - discovered
    print("Missing predicates identified through exploration:", missing)


def run_skeleton_task():
    '''Generic skeleton for running any task in your simulation.

    Sets up the environment and resets the task. It wraps ``task.step`` and
    ``task.get_observation`` with the video recording helpers. It then probes
    every known object with move/pick/place and pulls the first drawer-like
    object. Finally it prints which predicates were observed. The environment
    is always shut down, even when an episode ends early.
    '''
    print("===== Starting Skeleton Task =====")

    env, task = setup_environment()
    try:
        # Reset the task to its initial state; descriptions are unused here.
        _descriptions, obs = task.reset()

        # (Optional) Initialize video writers for capturing the simulation.
        init_video_writers(obs)

        # Replace step/get_observation with the video module's wrappers.
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        positions = get_object_positions()
        print("[Init] Found objects and positions:", positions)

        # Exploration phase to identify missing predicates.
        discovered = set()
        if not _explore_objects(env, task, positions, discovered):
            return
        if not _explore_drawer(env, task, positions, discovered):
            return
        _report_summary(discovered)
    finally:
        # Always ensure the environment is properly shut down.
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")

if __name__ == '__main__':
    # Script entry point: run the exploration skeleton once.
    run_skeleton_task()