# run_skeleton_task.py (Completed with Exploration and Safety/Calibration Checks)

import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

from skill_code import *  # Use only provided skills

from video import init_video_writers, recording_step, recording_get_observation

from object_positions import get_object_positions

def is_gripper_empty(task):
    """
    Report whether the gripper currently holds nothing.

    Fetches a fresh observation via ``task.get_observation()`` and inspects,
    in order of preference:

    1. ``holding``      -- the result is the negation of this attribute.
    2. ``gripper_open`` -- the gripper counts as empty when the value is
       greater than 0.9.

    When the observation exposes neither attribute the function returns True.
    That fallback is a guess that avoids a crash rather than a measurement.
    """
    observation = task.get_observation()
    absent = object()  # sentinel: distinguishes "no attribute" from falsy values

    holding_flag = getattr(observation, 'holding', absent)
    if holding_flag is not absent:
        return not holding_flag

    opening = getattr(observation, 'gripper_open', absent)
    if opening is not absent:
        # A (nearly) fully open gripper is taken to mean nothing is grasped.
        return opening > 0.9

    # Neither signal is available: assume empty (unsafe, but prevents a crash).
    return True

def calibrate_gripper_force(env, task, max_attempts=3):
    """
    Placeholder for a gripper force calibration sequence.

    No calibration is actually performed: ``env`` and ``task`` are accepted
    but not used, and the function only prints a start banner, one line per
    attempt (numbered from 1 up to ``max_attempts``), and a completion banner.
    Returns None.
    """
    print("[Calibration] Calibrating gripper force...")
    # A real implementation would invoke a calibration skill or read force
    # sensors on each attempt; here every attempt is simply announced.
    for attempt_number in range(1, max_attempts + 1):
        print(f"[Calibration] Attempt {attempt_number}...")
    print("[Calibration] Gripper force calibration complete.")

def check_for_stuck_or_collision(task, prev_pos, curr_pos, stuck_threshold=0.001):
    """
    Decide whether the gripper failed to move between two sampled positions.

    The Euclidean distance between ``prev_pos`` and ``curr_pos`` is compared
    against ``stuck_threshold``. A distance strictly below the threshold is
    reported (with a printed warning) as stuck/collided and yields True;
    anything else yields False. ``task`` is accepted but not used.
    """
    displacement = np.asarray(curr_pos) - np.asarray(prev_pos)
    moved = np.linalg.norm(displacement)

    # Strict "<" on purpose: a NaN distance compares False and is not flagged.
    is_stuck = bool(moved < stuck_threshold)
    if is_stuck:
        print("[Safety] Gripper may be stuck or collided! Distance moved:", moved)
    return is_stuck

def run_skeleton_task():
    """Run one pick-then-place episode against the simulated environment.

    Steps, in order:

    1. Create the environment, reset the task, and start the video writers.
    2. Replace ``task.step`` and ``task.get_observation`` with their
       recording wrappers.
    3. Read object positions and run the (stub) gripper force calibration.
    4. Pick the first object returned by ``get_object_positions()``.
    5. Place it at the second object's position, or, when only one object
       exists, at the pick position shifted by 0.2 along the first axis.

    The function prints a message and returns early when no objects are
    found, the gripper is not empty before the pick, a skill raises, a skill
    reports ``done``, or the gripper position barely changed across the place.
    ``shutdown_environment(env)`` runs on every exit path once setup succeeded.
    The closing banner is printed only when the full plan ran to the end.
    """
    print("===== Starting Skeleton Task =====")

    # === Environment Setup ===
    env, task = setup_environment()
    try:
        _descriptions, obs = task.reset()

        # Video capture is driven by the initial observation.
        init_video_writers(obs)

        # Route stepping and observation reads through the recording wrappers.
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # === Object positions and force calibration ===
        positions = get_object_positions()
        calibrate_gripper_force(env, task)

        # === Exploration Phase ===
        # The first listed object serves as the probe for the pick attempt.
        names = list(positions.keys())
        if not names:
            print("[Exploration] No objects found in environment!")
            return

        probe_name = names[0]
        probe_pos = positions[probe_name]
        print(f"[Exploration] Testing with object: {probe_name} at {probe_pos}")

        # Safety gate: never start a pick while something is already held.
        if not is_gripper_empty(task):
            print("[Safety] Gripper is not empty before pick! Aborting.")
            return

        # Motion parameters shared by the pick and the place skill calls.
        motion_kwargs = dict(
            approach_distance=0.15,
            max_steps=100,
            threshold=0.01,
            approach_axis='z',
            timeout=10.0,
        )

        try:
            obs, _reward, done = pick(
                env, task, target_pos=probe_pos, **motion_kwargs
            )
            if done:
                print("[Exploration] Task ended during pick!")
                return
        except Exception as exc:
            print(f"[Exploration] Exception during pick: {exc}")
            return

        # Predicate inspection is only simulated; nothing is read from state.
        print("[Exploration] Picked object, checking for missing predicates... (simulate)")

        # === Main Task Plan ===
        # Drop target: the second object's position when there is one,
        # otherwise a fixed offset from where the probe was picked up.
        target_pos = (
            positions[names[1]]
            if len(names) > 1
            else np.array(probe_pos) + np.array([0.2, 0, 0])
        )

        # Gripper position before the place, for the stuck check afterwards.
        pos_before = obs.gripper_pose[:3]
        try:
            obs, _reward, done = place(
                env, task, target_pos=target_pos, **motion_kwargs
            )
            pos_after = obs.gripper_pose[:3]
            if check_for_stuck_or_collision(task, pos_before, pos_after):
                print("[Safety] Aborting due to possible collision or stuck gripper.")
                return
            if done:
                print("[Task] Task ended after place!")
                return
        except Exception as exc:
            print(f"[Task] Exception during place: {exc}")
            return

        print("[Task] Main plan executed successfully.")

    finally:
        # Runs on every path out of the try block, including early returns.
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")


# Script entry point: run the task only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    run_skeleton_task()
