# run_skeleton_task.py (Completed with Exploration and Robustness)

import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

from skill_code import *  # Use only predefined skills: pick, place, move, rotate, pull

from video import init_video_writers, recording_step, recording_get_observation

from object_positions import get_object_positions

def _detect_objects(obs, positions):
    """Return the names of the objects reported for the environment.

    ``obs.object_list`` is preferred when the observation exposes it;
    otherwise the keys of ``positions`` are used when it is a dict.  An
    empty list is returned when neither source is usable or when reading
    them raises.
    """
    try:
        if hasattr(obs, 'object_list'):
            return list(obs.object_list)
        if isinstance(positions, dict):
            return list(positions.keys())
        print("[Exploration] Warning: Could not determine object list from observation or positions.")
    except Exception as e:
        print(f"[Exploration] Exception while retrieving object list: {e}")
    return []


def _lookup_position(positions, name):
    """Return the stored position for ``name``, or ``None`` if unavailable.

    The object list may come from the observation rather than from
    ``positions``, so ``positions`` is not guaranteed to be a mapping;
    anything without a callable ``get`` is treated as "no position known"
    instead of raising ``AttributeError``.
    """
    getter = getattr(positions, 'get', None)
    if not callable(getter):
        return None
    return getter(name, None)


def _is_drawer_locked(obs, drawer_name):
    """Best-effort check of whether ``drawer_name`` is reported as locked.

    Reads ``obs.is_locked`` (used as a mapping) or ``obs.locked_drawers``
    (used as a container) when present.  If neither attribute exists, or
    the probe itself fails, the drawer is treated as unlocked; a failure is
    reported instead of being silently discarded.
    """
    try:
        if hasattr(obs, 'is_locked'):
            return bool(obs.is_locked.get(drawer_name, False))
        if hasattr(obs, 'locked_drawers'):
            return drawer_name in obs.locked_drawers
    except Exception as e:
        print(f"[Task] Warning: Could not determine lock state of '{drawer_name}': {e}")
    return False


def _run_skill(label, announcement, obs, invoke):
    """Announce and execute one skill call, containing any failure.

    Args:
        label: Skill name used in the log messages (e.g. ``'pick'``).
        announcement: Message printed just before the skill runs.
        obs: Latest observation; returned unchanged if the skill fails.
        invoke: Zero-argument callable performing the skill call and
            returning ``(obs, reward, done)``.  A thunk is used so that a
            failure to resolve the skill name is caught here as well.

    Returns:
        ``(obs, done)`` where ``obs`` is the newest observation available
        and ``done`` is ``True`` only when the skill reported task end.
    """
    latest_obs = obs
    try:
        print(announcement)
        latest_obs, _reward, done = invoke()
        if done:
            print(f"[Task] Task ended during {label}!")
            return latest_obs, True
    except Exception as e:
        print(f"[Task] Exception during {label}: {e}")
    return latest_obs, False


def run_skeleton_task():
    '''Run the example plan: rotate, pull a drawer, pick a tomato, place it.

    Sets up the environment, wraps the task's ``step`` and
    ``get_observation`` for video recording, lists the known objects, and
    then executes each plan step only if the object it needs is present.
    A failing skill is logged and the plan continues with the next step;
    the function returns early as soon as a skill reports the task as done.
    The environment is always shut down on exit.
    '''
    print("===== Starting Skeleton Task =====")

    # === Environment Setup ===
    env, task = setup_environment()
    try:
        # Reset the task to its initial state
        _descriptions, obs = task.reset()

        # (Optional) Initialize video writers for capturing your simulation
        init_video_writers(obs)

        # Wrap the task steps for recording (if needed)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # === Retrieve Object Positions ===
        positions = get_object_positions()

        # === Exploration Phase: list objects and warn about missing ones ===
        object_list = _detect_objects(obs, positions)
        print(f"[Exploration] Detected objects in environment: {object_list}")

        required_objects = ['drawer', 'plate', 'tomato', 'top']  # Example names; adjust as needed
        for obj in required_objects:
            if obj not in object_list:
                print(f"[Exploration] Warning: Required object '{obj}' not found in environment.")

        # Keyword arguments shared by the pick and place skill calls.
        approach_kwargs = dict(
            approach_distance=0.15,
            max_steps=100,
            threshold=0.01,
            approach_axis='z',
            timeout=10.0,
        )

        # 1. Rotate the drawer (if present)
        drawer_name = 'drawer'
        if drawer_name in object_list:
            # Target orientation: 90 degrees about z, in xyzw order.
            target_quat = np.array([0, 0, np.sin(np.pi/4), np.cos(np.pi/4)])
            obs, done = _run_skill(
                'rotate',
                f"[Task] Rotating {drawer_name}...",
                obs,
                lambda: rotate(env, task, target_quat, threshold=0.05, timeout=10.0),
            )
            if done:
                return
        else:
            print(f"[Task] Error: {drawer_name} not found in object list. Skipping rotate.")

        # 2. Pull the 'top' drawer (if present and not locked)
        top_drawer = 'top'
        if top_drawer in object_list:
            if _is_drawer_locked(obs, top_drawer):
                print(f"[Task] Error: Drawer '{top_drawer}' is locked, cannot pull.")
            else:
                obs, done = _run_skill(
                    'pull',
                    f"[Task] Pulling {top_drawer}...",
                    obs,
                    lambda: pull(env, task, top_drawer),
                )
                if done:
                    return
        else:
            print(f"[Task] Error: Drawer '{top_drawer}' not found in object list. Skipping pull.")

        # 3. Pick up a tomato (if present)
        tomato_name = 'tomato'
        if tomato_name in object_list:
            tomato_pos = _lookup_position(positions, tomato_name)
            if tomato_pos is not None:
                obs, done = _run_skill(
                    'pick',
                    f"[Task] Picking up {tomato_name} at {tomato_pos}...",
                    obs,
                    lambda: pick(env, task, target_pos=tomato_pos, **approach_kwargs),
                )
                if done:
                    return
            else:
                print(f"[Task] Error: Position for '{tomato_name}' not found.")
        else:
            print(f"[Task] Error: '{tomato_name}' not found in object list. Skipping pick.")

        # 4. Place tomato on plate (if plate exists)
        # NOTE(review): place is attempted even when the pick step was
        # skipped or failed -- confirm whether it should be gated on a
        # successful pick.
        plate_name = 'plate'
        if plate_name in object_list:
            plate_pos = _lookup_position(positions, plate_name)
            if plate_pos is not None:
                obs, done = _run_skill(
                    'place',
                    f"[Task] Placing {tomato_name} on {plate_name} at {plate_pos}...",
                    obs,
                    lambda: place(env, task, target_pos=plate_pos, **approach_kwargs),
                )
                if done:
                    return
            else:
                print(f"[Task] Error: Position for '{plate_name}' not found.")
        else:
            print(f"[Task] Error: '{plate_name}' not found in object list. Skipping place.")

        # === End of Plan ===

    finally:
        # Always ensure the environment is properly shutdown
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")


if __name__ == '__main__':
    # Executed as a script (not imported): run the task once.
    run_skeleton_task()