# run_skeleton_task.py (Completed with Exploration Phase and Predicate Discovery)

import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

from skill_code import *  # Use only predefined skills: pick, place, move, rotate, pull

from video import init_video_writers, recording_step, recording_get_observation

from object_positions import get_object_positions

# Keyword arguments shared by every position-targeted skill call in the
# exploration phase (pick, place, move, pull); values match the original
# per-call literals.
_POSITION_SKILL_KWARGS = {
    'approach_distance': 0.15,
    'max_steps': 100,
    'threshold': 0.01,
    'approach_axis': 'z',
    'timeout': 10.0,
}

# Substrings that mark an object as a candidate for the rotate/pull probe.
_PULLABLE_KEYWORDS = ("drawer", "handle")


def _report_episode_end(done):
    '''Print the termination notice if a skill reported ``done``; return whether it did.'''
    if done:
        print("[Exploration] Task ended during exploration!")
    return bool(done)


def _explore_pick_place(env, task, positions):
    '''Pick every object and place it back at its recorded position.

    Failures of an individual object are logged and exploration continues
    with the next object (best-effort).

    Returns:
        True as soon as any skill reports that the episode has ended,
        otherwise False.
    '''
    for obj_name, obj_pos in positions.items():
        print(f"[Exploration] Attempting to pick object: {obj_name} at position {obj_pos}")
        try:
            _obs, _reward, done = pick(env, task, target_pos=obj_pos, **_POSITION_SKILL_KWARGS)
            print(f"[Exploration] Picked {obj_name}. Checking for missing information...")
            # Check termination before issuing place(): otherwise place()'s
            # return value would overwrite pick()'s `done` and the end of the
            # episode would go unnoticed.
            if _report_episode_end(done):
                return True

            # No direct access to predicates, so log the raw observation for analysis.
            obs_after_pick = task.get_observation()
            print(f"[Exploration] Observation after picking {obj_name}: {obs_after_pick}")

            _obs, _reward, done = place(env, task, target_pos=obj_pos, **_POSITION_SKILL_KWARGS)
            print(f"[Exploration] Placed {obj_name} back.")
            if _report_episode_end(done):
                return True
        except Exception as e:
            print(f"[Exploration] Exception during pick/place of {obj_name}: {e}")
    return False


def _explore_move(env, task, positions):
    '''Move to every object's recorded position and log the resulting observation.

    Returns:
        True as soon as a move reports that the episode has ended, otherwise False.
    '''
    for obj_name, obj_pos in positions.items():
        print(f"[Exploration] Attempting to move to object: {obj_name} at position {obj_pos}")
        try:
            _obs, _reward, done = move(env, task, target_pos=obj_pos, **_POSITION_SKILL_KWARGS)
            print(f"[Exploration] Moved to {obj_name}. Checking for missing information...")

            obs_after_move = task.get_observation()
            print(f"[Exploration] Observation after moving to {obj_name}: {obs_after_move}")

            if _report_episode_end(done):
                return True
        except Exception as e:
            print(f"[Exploration] Exception during move to {obj_name}: {e}")
    return False


def _explore_rotate_pull(env, task, positions):
    '''Rotate then pull every object whose name contains a pullable keyword.

    Returns:
        True as soon as a skill reports that the episode has ended, otherwise False.
    '''
    for obj_name, obj_pos in positions.items():
        if not any(keyword in obj_name for keyword in _PULLABLE_KEYWORDS):
            continue
        print(f"[Exploration] Attempting to pull/rotate {obj_name} at position {obj_pos}")
        try:
            # Placeholder orientation: identity quaternion. NOTE(review): presumably
            # (x, y, z, w) ordering and not tied to the object's pose — confirm
            # against the rotate() skill.
            target_quat = np.array([0, 0, 0, 1])
            _obs, _reward, done = rotate(
                env,
                task,
                target_quat=target_quat,
                max_steps=100,
                threshold=0.05,
                timeout=10.0,
            )
            print(f"[Exploration] Rotated {obj_name}. Checking for missing information...")
            # Same reasoning as pick/place: do not let pull() overwrite rotate()'s `done`.
            if _report_episode_end(done):
                return True

            _obs, _reward, done = pull(env, task, target_pos=obj_pos, **_POSITION_SKILL_KWARGS)
            print(f"[Exploration] Pulled {obj_name}. Checking for missing information...")

            obs_after_pull = task.get_observation()
            print(f"[Exploration] Observation after pulling {obj_name}: {obs_after_pull}")

            if _report_episode_end(done):
                return True
        except Exception as e:
            print(f"[Exploration] Exception during rotate/pull of {obj_name}: {e}")
    return False


def run_skeleton_task():
    '''Generic skeleton for running any task in your simulation, with exploration for missing predicates.

    Sets up the environment, resets the task, wraps ``task.step`` and
    ``task.get_observation`` with the video-recording helpers, then runs three
    best-effort exploration passes over the objects returned by
    ``get_object_positions()``:

    1. pick each object and place it back,
    2. move to each object,
    3. rotate + pull objects whose names contain "drawer" or "handle".

    Observations are printed for offline analysis. If any skill reports that
    the episode has ended, exploration stops immediately. The environment is
    always shut down, even on early exit or exception.
    '''
    print("===== Starting Skeleton Task =====")

    # === Environment Setup ===
    env, task = setup_environment()
    try:
        # Reset the task to its initial state (the descriptions are not used here).
        _descriptions, obs = task.reset()

        # (Optional) Initialize video writers for capturing your simulation
        init_video_writers(obs)

        # Wrap step/get_observation so every call is also recorded.
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # Assumes positions is a dict: {object_name: position, ...}.
        positions = get_object_positions()
        object_names = list(positions.keys())
        print(f"[Exploration] Detected objects: {object_names}")

        # Each pass returns True once the episode has ended; stop at the first.
        for explore in (_explore_pick_place, _explore_move, _explore_rotate_pull):
            if explore(env, task, positions):
                return

        print("[Exploration] Exploration phase complete. Analyze logs to determine missing predicate.")

        # After exploration, the oracle plan would normally be executed using
        # the discovered predicates; this skeleton focuses on exploration only.
    finally:
        # Always ensure the environment is properly shutdown
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Run the exploration skeleton only when executed as a script, not on import.
    run_skeleton_task()
