import numpy as np
import re
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment
from skill_code import move, pick, place, rotate, pull
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions

# Heuristic patterns for pulling PDDL-style predicate names out of skill error
# messages. They are compiled once here instead of on every failure.
#   "(not (handempty))" -> "handempty"
_NEGATED_PRED_RE = re.compile(r'\(\s*not\s*\(\s*([\w-]+)')
#   "(at ?obj ?loc)" -> "at"  (needs at least two whitespace-separated tokens after
#   the name; any parenthesised run of three or more words also matches)
_MULTI_ARG_PRED_RE = re.compile(r'\(\s*([\w-]+)\s+[\?\w-]+\s+[\?\w-]+')


def _extract_predicates_from_error(msg):
    """Return the predicate names mentioned in a skill error message.

    Recognises negated atoms such as ``(not (handempty))`` and atoms with at
    least two arguments such as ``(at ?obj ?loc)``. The returned list may
    contain duplicates (e.g. ``(not (at ?o ?l))`` yields ``"at"`` twice);
    callers collect it into a set.

    :param msg: the error text, typically ``str(exception)``.
    :return: list of predicate names, possibly empty.
    """
    return _NEGATED_PRED_RE.findall(msg) + _MULTI_ARG_PRED_RE.findall(msg)


def _record_failure(action, name, exc, missing_predicates):
    """Log a failed skill call and add any predicates named in its error to the set."""
    err = str(exc)
    print(f"[Exploration] {action} failed for '{name}': {err}")
    missing_predicates.update(_extract_predicates_from_error(err))


def _episode_ended(done, skill):
    """Print the exit notice and return True when a skill reported the episode as done."""
    if done:
        print(f"[Task] Episode ended during {skill}. Exiting.")
        return True
    return False


def _release_after_failed_pick(env, task, obj_name, pos):
    """Best-effort place after a failed pick; return True if the episode ended."""
    try:
        _, _, done = place(env, task,
                           target_pos=pos,
                           approach_distance=0.10,
                           max_steps=50,
                           threshold=0.01,
                           approach_axis='z',
                           timeout=3.0)
    except Exception as cleanup_err:
        # There may be nothing in the gripper to place, so this is not recorded as a
        # missing predicate. It is logged rather than silently discarded.
        print(f"[Exploration] Cleanup place after failed pick of '{obj_name}' "
              f"also failed: {cleanup_err}")
        return False
    return _episode_ended(done, "cleanup place")


def _explore_object(env, task, obj_name, pos, missing_predicates):
    """Move to, pick up and put back one object; return True if the episode ended.

    Each skill failure is logged and its predicates are added to
    *missing_predicates*. A failure at one stage skips the remaining stages for
    this object.
    """
    print(f"[Exploration] Processing object '{obj_name}' at position {pos}")

    # 1) Move to the object to identify it and measure temperature.
    try:
        _, _, done = move(env, task,
                          target_pos=pos,
                          max_steps=100,
                          threshold=0.01,
                          timeout=5.0)
    except Exception as e:
        _record_failure("Move", obj_name, e, missing_predicates)
        return False
    if _episode_ended(done, "move"):
        return True
    print(f"[Exploration] Moved to '{obj_name}'. Identification and temperature measurement assumed.")

    # 2) Attempt to pick up the object to learn weight and durability.
    try:
        _, _, done = pick(env, task,
                          target_pos=pos,
                          approach_distance=0.10,
                          max_steps=80,
                          threshold=0.01,
                          approach_axis='z',
                          timeout=5.0)
    except Exception as e:
        _record_failure("Pick", obj_name, e, missing_predicates)
        # Try to put back anything the failed pick may have left in the gripper.
        return _release_after_failed_pick(env, task, obj_name, pos)
    if _episode_ended(done, "pick"):
        return True
    print(f"[Exploration] Picked up '{obj_name}'. Weight and durability measurement assumed.")

    # 3) Place the object back to free the gripper.
    try:
        _, _, done = place(env, task,
                           target_pos=pos,
                           approach_distance=0.10,
                           max_steps=80,
                           threshold=0.01,
                           approach_axis='z',
                           timeout=5.0)
    except Exception as e:
        _record_failure("Place", obj_name, e, missing_predicates)
        return False
    if _episode_ended(done, "place"):
        return True
    print(f"[Exploration] Placed '{obj_name}' back.")
    return False


def _pull_drawer_handles(env, task, positions, missing_predicates):
    """Try pulling every drawer handle in *positions*; return True if the episode ended."""
    # Handles are recognised purely by name: keys containing both 'drawer' and 'handle'.
    for key, pos in positions.items():
        if 'drawer' not in key or 'handle' not in key:
            continue
        print(f"[Exploration] Attempting to pull '{key}' at {pos}")
        try:
            # Approach and grasp logic could be more complex; here pull is called directly.
            _, _, done = pull(env, task,
                              target_pos=pos,
                              max_steps=80,
                              threshold=0.01,
                              timeout=5.0)
        except Exception as e:
            _record_failure("Pull", key, e, missing_predicates)
            continue
        if _episode_ended(done, "pull"):
            return True
        print(f"[Exploration] Pulled '{key}'. Lock state measurement assumed.")
    return False


def run_skeleton_task():
    """Exploration skeleton for identifying missing predicates via skill failures.

    Sets up the environment, then runs move/pick/place on every known object,
    followed by pull on every drawer handle. Predicate names are harvested from
    the error messages of failing skills and printed at the end. The summary is
    printed even when the episode ends early, so findings are not lost. The
    environment is always shut down.
    """
    print("===== Starting Skeleton Task =====")
    env, task = setup_environment()
    try:
        _, obs = task.reset()
        init_video_writers(obs)
        # Wrap step and get_observation so every call is recorded to video.
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # Presumably maps object name -> (x, y, z) position; TODO confirm against
        # object_positions.get_object_positions.
        positions = get_object_positions()
        if not positions:
            print("[Error] No objects found in the scene.")
            return

        missing_predicates = set()

        # NOTE(review): every entry, including drawer handles, goes through
        # move/pick/place here before the dedicated pull pass below.
        episode_over = False
        for obj_name, pos in positions.items():
            if _explore_object(env, task, obj_name, pos, missing_predicates):
                episode_over = True
                break

        # Only attempt drawer pulls while the episode is still running.
        if not episode_over:
            episode_over = _pull_drawer_handles(env, task, positions, missing_predicates)

        if episode_over:
            print("===== Exploration Stopped Early (episode ended) =====")
        else:
            print("===== Exploration Complete =====")
        if missing_predicates:
            print(f"[Result] Missing predicates identified during exploration: {missing_predicates}")
        else:
            print("[Result] No missing predicates detected; all skill preconditions satisfied.")

    finally:
        shutdown_environment(env)
        print("===== End of Skeleton Task =====")

if __name__ == "__main__":
    # Script entry point: run the exploration only when executed directly, not on import.
    run_skeleton_task()