# run_skeleton_task.py (Completed with Exploration and Safety/Calibration)

import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

from skill_code import *  # Use only predefined skills

from video import init_video_writers, recording_step, recording_get_observation

from object_positions import get_object_positions

import time
import traceback

def check_object_list(required_objects, available_objects):
    """Verify that every required object name appears among the available ones.

    Args:
        required_objects: Iterable of object names the task needs.
        available_objects: Iterable of object names present in the environment.

    Returns:
        True when nothing is missing. Otherwise prints the set of missing
        names and returns False.
    """
    absent = set(required_objects) - set(available_objects)
    if not absent:
        return True
    print(f"[ERROR] Missing objects in environment: {absent}")
    return False

def calibrate_gripper(env, task):
    """Run the gripper's force/torque calibration when the environment offers one.

    This is a best-effort step. If ``env`` has no ``gripper`` attribute, or
    the gripper has no ``calibrate`` method, a notice is printed and nothing
    else happens. Any exception raised while calibrating is printed, not
    propagated. ``task`` is accepted for interface symmetry but not used.
    """
    try:
        can_calibrate = hasattr(env, 'gripper') and hasattr(env.gripper, 'calibrate')
        if not can_calibrate:
            print("[Calibration] Gripper calibration not available in this environment.")
            return
        print("[Calibration] Calibrating gripper force/torque...")
        env.gripper.calibrate()
    except Exception as exc:
        print(f"[Calibration] Exception during gripper calibration: {exc}")

def check_safety(env, task, obj_pos, min_distance=0.05):
    """Decide whether it is safe to move the gripper toward ``obj_pos``.

    The current gripper position is taken from the first three entries of
    ``task.get_observation().gripper_pose``. Its Euclidean distance to
    ``obj_pos`` is compared against ``min_distance``.

    The check fails closed. It returns False when:
      * the gripper is closer than ``min_distance`` to ``obj_pos``;
      * the distance cannot be computed (any exception is printed); or
      * the distance is not finite (e.g. a NaN/inf coordinate). Without this
        guard, ``NaN < min_distance`` is False and the check would wrongly pass.

    Args:
        env: Unused here; kept for interface compatibility.
        task: Object exposing ``get_observation()``.
        obj_pos: Target position; must be array-like and numeric.
        min_distance: Minimum allowed gripper-to-target distance.

    Returns:
        True only when a finite distance of at least ``min_distance`` was
        measured.
    """
    try:
        obs = task.get_observation()
        gripper_pos = np.asarray(obs.gripper_pose[:3], dtype=float)
        target_pos = np.asarray(obj_pos, dtype=float)
        dist = float(np.linalg.norm(gripper_pos - target_pos))
    except Exception as e:
        print(f"[Safety] Exception during safety check: {e}")
        return False

    if not np.isfinite(dist):
        print(f"[Safety] Warning: Non-finite distance to target ({dist}); treating as unsafe.")
        return False
    if dist < min_distance:
        print(f"[Safety] Warning: Gripper is too close to object (distance: {dist:.3f}m).")
        return False
    return True

def _pick_at(env, task, target_pos):
    """Call the ``pick`` skill toward ``target_pos`` with this script's fixed settings.

    Returns whatever ``pick`` returns; callers unpack it as ``obs, reward, done``.
    """
    return pick(
        env,
        task,
        target_pos=target_pos,
        approach_distance=0.15,
        max_steps=100,
        threshold=0.01,
        approach_axis='z',
        timeout=5.0,
    )


def _place_at(env, task, target_pos):
    """Call the ``place`` skill toward ``target_pos`` with this script's fixed settings.

    Returns whatever ``place`` returns; callers unpack it as ``obs, reward, done``.
    """
    return place(
        env,
        task,
        target_pos=target_pos,
        approach_distance=0.15,
        max_steps=100,
        threshold=0.01,
        approach_axis='z',
        timeout=5.0,
    )


def _explore_objects(env, task, positions, timeout=10.0):
    """Try one pick of every object in ``positions`` within a wall-clock budget.

    Per-object failures (safety skip or exception) are logged and the loop
    moves on to the next object.

    Returns:
        True if the task reported ``done`` during exploration, so the caller
        should stop. False otherwise.
    """
    start = time.time()
    for obj, obj_pos in positions.items():
        # Check the budget before every attempt. Objects skipped by the
        # safety check are then also bounded by the timeout.
        if time.time() - start > timeout:
            print("[Exploration] Timeout reached during exploration phase.")
            break
        print(f"[Exploration] Attempting to pick object: {obj} at {obj_pos}")
        if not check_safety(env, task, obj_pos):
            print(f"[Exploration] Skipping {obj} due to safety constraint.")
            continue
        try:
            _obs, _reward, done = _pick_at(env, task, obj_pos)
        except Exception as e:
            print(f"[Exploration] Exception during pick of {obj}: {e}")
            traceback.print_exc()
            continue
        print(f"[Exploration] Picked {obj}. Checking for new predicates in observation/state...")
        # Placeholder: a real system would inspect the state here for newly
        # revealed predicates (e.g. weight-known, durability-known).
        if done:
            print(f"[Exploration] Task ended after picking {obj}!")
            return True
    return False


def _refresh_positions(previous):
    """Re-query object positions after exploration, falling back to ``previous``.

    Exploration picks objects and never puts them back, so positions captured
    before it may be stale.
    """
    try:
        return get_object_positions()
    except Exception as e:
        print(f"[Task] Could not refresh object positions, reusing previous ones: {e}")
        return previous


def _rotate_at_drawer(env, task, drawer_pos):
    """Rotate the gripper 90 degrees about z in preparation for the drawer.

    Only the rotation is performed; no translation toward ``drawer_pos`` is
    commanded. ``drawer_pos`` is used for the safety check and for logging.

    Returns:
        True if the task reported ``done``, otherwise False.
    """
    if drawer_pos is None:
        print("[Task] Drawer not found in positions.")
        return False
    print(f"[Task] Approaching drawer at {drawer_pos}")
    if not check_safety(env, task, drawer_pos):
        print("[Task] Skipping drawer manipulation due to safety.")
        return False
    # 90-degree rotation about z, written as (0, 0, sin(theta/2), cos(theta/2)).
    # NOTE(review): assumes (x, y, z, w) quaternion ordering - confirm
    # against the rotate skill.
    half_angle = np.pi / 4
    target_quat = [0, 0, np.sin(half_angle), np.cos(half_angle)]
    try:
        _obs, _reward, done = rotate(
            env,
            task,
            target_quat=target_quat,
            max_steps=50,
            threshold=0.05,
            timeout=5.0,
        )
    except Exception as e:
        print(f"[Task] Exception during rotate: {e}")
        traceback.print_exc()
        return False
    if done:
        print("[Task] Task ended during drawer rotation!")
    return bool(done)


def _pick_named(env, task, positions, name):
    """Pick the object called ``name`` at its entry in ``positions``.

    Returns:
        ``(picked, done)``. ``picked`` is True only if the pick skill ran
        without raising. ``done`` is True if the task reported completion.
    """
    obj_pos = positions.get(name)
    if obj_pos is None:
        print(f"[Task] {name} not found in positions.")
        return False, False
    print(f"[Task] Picking {name} at {obj_pos}")
    if not check_safety(env, task, obj_pos):
        print(f"[Task] Skipping {name} pick due to safety.")
        return False, False
    try:
        _obs, _reward, done = _pick_at(env, task, obj_pos)
    except Exception as e:
        print(f"[Task] Exception during pick of {name}: {e}")
        traceback.print_exc()
        return False, False
    if done:
        print(f"[Task] Task ended after picking {name}!")
    return True, bool(done)


def _place_held(env, task, positions, name, zone):
    """Place the held object ``name`` at the position recorded for ``zone``.

    Returns:
        True if the task reported ``done``, otherwise False.
    """
    zone_pos = positions.get(zone)
    if zone_pos is None:
        print(f"[Task] {zone} not found in positions.")
        return False
    print(f"[Task] Placing {name} at {zone_pos}")
    if not check_safety(env, task, zone_pos):
        print("[Task] Skipping place due to safety.")
        return False
    try:
        _obs, _reward, done = _place_at(env, task, zone_pos)
    except Exception as e:
        print(f"[Task] Exception during place of {name}: {e}")
        traceback.print_exc()
        return False
    if done:
        print(f"[Task] Task ended after placing {name}!")
    return bool(done)


def run_skeleton_task():
    '''Generic skeleton for running any task in your simulation.

    Steps:
      1. Set up and reset the environment.
      2. Wrap ``task.step`` / ``task.get_observation`` for video recording.
      3. Verify the required objects exist and calibrate the gripper.
      4. Run an exploration phase that tries to pick every object.
      5. Re-query object positions, since exploration may have moved things.
      6. Execute a fixed demo plan: rotate at the drawer, pick ``object_1``,
         and, only if that pick ran, place it at ``target_zone``.

    The run stops early whenever a skill reports ``done``. The environment
    is always shut down and the end banner is always printed, including
    after an early return.
    '''
    print("===== Starting Skeleton Task =====")

    env, task = setup_environment()
    try:
        _descriptions, obs = task.reset()

        # (Optional) video capture: wrap stepping/observation for recording.
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        positions = get_object_positions()
        available_objects = list(positions.keys())
        print(f"[Info] Available objects in environment: {available_objects}")

        # Hard-coded for this demo; a real plan would derive these from the
        # oracle plan or the observation.
        required_objects = ['drawer', 'object_1']
        if not check_object_list(required_objects, available_objects):
            print("[Task] Aborting due to missing objects.")
            return

        calibrate_gripper(env, task)

        if _explore_objects(env, task, positions):
            return

        positions = _refresh_positions(positions)
        # NOTE(review): exploration never releases what it picks, so the
        # gripper may still be holding an object here. Confirm whether a
        # release step is needed before the main plan.

        if _rotate_at_drawer(env, task, positions.get('drawer')):
            return

        picked, done = _pick_named(env, task, positions, 'object_1')
        if done:
            return
        if picked:
            if _place_held(env, task, positions, 'object_1', 'target_zone'):
                return
        else:
            # Placing with nothing in hand would only move an empty gripper.
            print("[Task] Skipping place because object_1 was not picked.")

        print("[Task] Main plan execution completed.")

    finally:
        # Always shut the environment down, even on early return or error.
        shutdown_environment(env)
        print("===== End of Skeleton Task =====")


if __name__ == '__main__':
    # Script entry point: run the skeleton task once when executed directly.
    run_skeleton_task()