# run_skeleton_task.py (Completed with Exploration Phase and Safety Checks)

import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

from skill_code import *  # Use only predefined skills: move, pick, place, rotate, pull

from video import init_video_writers, recording_step, recording_get_observation

from object_positions import get_object_positions

def is_position_safe(target_pos, object_positions, safety_margin=0.05):
    """
    Report whether a target point keeps a minimum clearance from every known object.

    Objects whose position is None are ignored. The first object found closer
    than ``safety_margin`` is logged and causes an immediate False.

    Args:
        target_pos (np.ndarray): Target [x, y, z] position.
        object_positions (dict): Mapping of object name -> position (or None).
        safety_margin (float): Minimum allowed Euclidean distance to any object.
    Returns:
        bool: True when no object is within ``safety_margin``, False otherwise.
    """
    known = ((name, pos) for name, pos in object_positions.items() if pos is not None)
    for name, pos in known:
        separation = np.linalg.norm(target_pos - np.array(pos))
        if separation < safety_margin:
            print(f"[Safety] Target position too close to {name}: {separation:.3f}m")
            return False
    # NOTE: no workspace-bounds check is performed; add one here if limits are known.
    return True

def calibrate_gripper_force(env, task):
    """
    Placeholder for a gripper force calibration step.

    Nothing is calibrated in simulation; the routine only announces that the
    step was skipped. ``env`` and ``task`` are accepted so the call site matches
    a real implementation, but they are not used here.
    """
    notice = "[Calibration] Calibrating gripper force... (skipped in simulation)"
    print(notice)

def improved_distance_threshold(current_pos, target_pos, orientation_weight=0.1):
    """
    Return the Euclidean distance between two positions.

    Inputs may be any array-like (list, tuple, ndarray). They are converted to
    float arrays first, so plain Python sequences work too; subtracting two
    lists directly would raise a TypeError.

    Args:
        current_pos: Current position, e.g. [x, y, z].
        target_pos: Target position with the same length as ``current_pos``.
        orientation_weight (float): Reserved for a future orientation term.
            Currently unused; kept for interface compatibility.
    Returns:
        float: Euclidean norm of ``current_pos - target_pos``.
    """
    current = np.asarray(current_pos, dtype=float)
    target = np.asarray(target_pos, dtype=float)
    return float(np.linalg.norm(current - target))

def explore_for_missing_predicate(env, task, object_positions):
    """
    Exploration phase: try move -> pick -> place on every known object and log the outcome.

    Objects are visited in dictionary order; entries with a None position are
    skipped. For each object the function attempts ``move`` to its position,
    ``pick`` at that position, and ``place`` back at the same position. It
    prints the result of every step, so a failing step hints at a precondition
    (predicate) the planner is missing. A skill that raises is logged, and
    exploration moves on to the next object.

    Args:
        env: Environment handle passed through to the skills.
        task: Task handle; must provide ``get_observation()``.
        object_positions (dict): Object name -> [x, y, z] position (or None).

    Returns:
        bool: True if a skill reported ``done`` (exploration stops right away
        in that case), False once all objects have been tried.
    """
    print("===== [Exploration] Starting exploration phase to identify missing predicates =====")
    obs = task.get_observation()
    print(f"[Exploration] Initial gripper position: {obs.gripper_pose[:3]}")

    for obj_name, obj_pos in object_positions.items():
        if obj_pos is None:
            continue
        print(f"[Exploration] Attempting to move to {obj_name} at {obj_pos}")

        # The target IS this object's position. Its own entry (distance 0)
        # must be excluded from the clearance check; otherwise every object
        # would be reported unsafe and skipped.
        others = {name: pos for name, pos in object_positions.items() if name != obj_name}
        if not is_position_safe(np.array(obj_pos), others):
            print(f"[Exploration] Skipping {obj_name} due to unsafe position.")
            continue

        try:
            _obs, _reward, done = move(env, task, np.array(obj_pos))
        except Exception as e:  # skills are opaque; log and try the next object
            print(f"[Exploration] Exception during move to {obj_name}: {e}")
            continue
        if done:
            print(f"[Exploration] Task ended unexpectedly after move to {obj_name}.")
            return True

        print(f"[Exploration] Attempting to pick {obj_name}")
        try:
            _obs, _reward, done = pick(env, task, target_pos=np.array(obj_pos))
        except Exception as e:
            print(f"[Exploration] Exception during pick of {obj_name}: {e}")
            continue
        if done:
            print(f"[Exploration] Task ended unexpectedly after pick of {obj_name}.")
            return True

        # Put the object back where it was found so later steps see a similar scene.
        print(f"[Exploration] Attempting to place {obj_name} at {obj_pos}")
        try:
            _obs, _reward, done = place(env, task, target_pos=np.array(obj_pos))
        except Exception as e:
            print(f"[Exploration] Exception during place of {obj_name}: {e}")
            continue
        if done:
            print(f"[Exploration] Task ended unexpectedly after place of {obj_name}.")
            return True

    print("===== [Exploration] Exploration phase complete. Check logs for missing predicate clues. =====")
    return False

def _demo_pick_and_place(env, task, positions):
    """Move to, pick, and place back the first object that can be handled; return True if the task reported done."""
    for obj_name, obj_pos in positions.items():
        if obj_pos is None:
            continue
        print(f"[Task] Attempting to move to {obj_name} at {obj_pos}")

        # Exclude the target object itself from the clearance check: its
        # distance to its own position is 0, which would mark it unsafe.
        others = {name: pos for name, pos in positions.items() if name != obj_name}
        if not is_position_safe(np.array(obj_pos), others):
            print(f"[Task] Skipping {obj_name} due to unsafe position.")
            continue

        try:
            _obs, _reward, done = move(env, task, np.array(obj_pos))
        except Exception as e:  # skills are opaque; log and try the next object
            print(f"[Task] Exception during move to {obj_name}: {e}")
            continue
        if done:
            print(f"[Task] Task ended after move to {obj_name}!")
            return True

        print(f"[Task] Attempting to pick {obj_name}")
        try:
            _obs, _reward, done = pick(env, task, target_pos=np.array(obj_pos))
        except Exception as e:
            print(f"[Task] Exception during pick of {obj_name}: {e}")
            continue
        if done:
            print(f"[Task] Task ended after pick of {obj_name}!")
            return True

        # For demonstration, place at the same position.
        print(f"[Task] Attempting to place {obj_name} at {obj_pos}")
        try:
            _obs, _reward, done = place(env, task, target_pos=np.array(obj_pos))
        except Exception as e:
            print(f"[Task] Exception during place of {obj_name}: {e}")
            continue
        if done:
            print(f"[Task] Task ended after place of {obj_name}!")
            return True

        # Only one object is handled in this demo.
        return False
    return False


def run_skeleton_task():
    """
    Generic skeleton for running any task in your simulation.

    Steps:
    1. Set up the environment and reset the task.
    2. Wrap ``task.step`` and ``task.get_observation`` for video recording.
    3. Run the (simulated) force calibration.
    4. Run the exploration phase.
    5. Run a one-object pick-and-place demo plan.

    The demo plan is skipped if the task already reported done during
    exploration. The environment is always shut down, even on error.
    """
    print("===== Starting Skeleton Task =====")

    # === Environment Setup ===
    env, task = setup_environment()
    try:
        # Reset the task to its initial state
        descriptions, obs = task.reset()

        # (Optional) Initialize video writers for capturing your simulation
        init_video_writers(obs)

        # Wrap the task steps for recording
        original_step = task.step
        task.step = recording_step(original_step)
        original_get_obs = task.get_observation
        task.get_observation = recording_get_observation(original_get_obs)

        # === Retrieve Object Positions ===
        positions = get_object_positions()

        # === Force Calibration (if needed) ===
        calibrate_gripper_force(env, task)

        # === Exploration Phase: Identify Missing Predicate ===
        # A truthy return means the task already reported done; a None/False
        # return lets the plan proceed.
        task_done = bool(explore_for_missing_predicate(env, task, positions))

        if task_done:
            print("[Task] Task finished during exploration; skipping the task plan.")
        else:
            # Exploration picked and re-placed objects, so re-query positions
            # rather than trusting the pre-exploration snapshot.
            positions = get_object_positions()

            # === Example Task Plan (Replace with actual plan as needed) ===
            task_done = _demo_pick_and_place(env, task, positions)

        # TODO: Insert further plan steps (rotate, pull, etc.) as required by
        # the oracle plan; skip them when task_done is True.

    finally:
        # Always ensure the environment is properly shutdown
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Run the skeleton only when executed as a script, not when imported.
    run_skeleton_task()
