import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

from skill_code import *  # move, pick, place, rotate, pull, etc.

from video import init_video_writers, recording_step, recording_get_observation

from object_positions import get_object_positions

def _flag_is_set(obs, flag, obj=None):
    """Return True if ``obs`` reports the boolean ``flag`` as set.

    Both the per-object attribute ``<flag>_<obj>`` (when ``obj`` is given)
    and the plain attribute ``<flag>`` are consulted, so every exploration
    phase uses the same lookup rule regardless of which naming scheme the
    observation exposes. Missing attributes (or ``obs`` being None) count
    as "not set".

    NOTE(review): the observation attribute names are an assumption of this
    skeleton, not something verified here - confirm against the environment.
    """
    if obj is not None and getattr(obs, f'{flag}_{obj}', False):
        return True
    return bool(getattr(obs, flag, False))


def _explore_move(env, task, object_keys):
    """Move between each consecutive pair of keys; return predicates seen.

    Probes for 'identified' and 'temperature-known'. Failures are logged
    and skipped so exploration stays best-effort.
    """
    found = set()
    for src, dst in zip(object_keys, object_keys[1:]):
        try:
            obs, _reward, _done = move(env, task, src, dst)
            # After every move, any object may have become identified or
            # had its temperature revealed, so scan all of them.
            for obj in object_keys:
                if _flag_is_set(obs, 'identified', obj):
                    found.add('identified')
                if _flag_is_set(obs, 'temperature_known', obj):
                    found.add('temperature-known')
        except Exception as e:
            print(f"[Exploration][move] failed from {src} to {dst}: {e}")
    return found


def _explore_pick_place(env, task, positions):
    """Pick each object and place it back; return predicates seen.

    Probes for 'weight-known' and 'durability-known' (pick may reveal
    either). Failures are logged and skipped.
    """
    found = set()
    for obj, pos in positions.items():
        try:
            obs, _reward, _done = pick(env, task, obj, pos)
            if _flag_is_set(obs, 'weight_known', obj):
                found.add('weight-known')
            if _flag_is_set(obs, 'durability_known', obj):
                found.add('durability-known')
            # Put the object back so later phases see the original layout.
            place(env, task, obj, pos)
        except Exception as e:
            print(f"[Exploration][pick/place] failed on {obj}: {e}")
    return found


def _explore_pull(env, task, positions, handle_key):
    """Grasp and pull the handle named ``handle_key``; return predicates seen.

    Probes for 'lock-known'. If ``handle_key`` is absent from ``positions``
    a warning is printed and nothing is attempted.
    """
    found = set()
    if handle_key not in positions:
        print(f"[Exploration] Warning: handle key '{handle_key}' not in positions")
        return found
    hpos = positions[handle_key]
    try:
        pick(env, task, handle_key, hpos)
        obs, _reward, _done = pull(env, task, handle_key, hpos)
        if _flag_is_set(obs, 'lock_known', handle_key):
            found.add('lock-known')
        # NOTE(review): the handle is not released after pulling; add a
        # release/place step here if later phases need a free gripper.
    except Exception as e:
        print(f"[Exploration][pull] failed on {handle_key}: {e}")
    return found


def run_skeleton_task(*, handle_key='drawer_handle'):
    """Run the generic task skeleton with a predicate-exploration phase.

    Sets up the environment, wraps ``task.step`` / ``task.get_observation``
    for video recording, then exercises move, pick/place and pull skills to
    see which of the expected exploration predicates show up in the
    returned observations. Discovered and missing predicates are printed.
    The environment is always shut down, even if a step raises.

    Args:
        handle_key: key in the object-position map of the handle used for
            the pull probe. Defaults to 'drawer_handle'.
    """
    print("===== Starting Skeleton Task =====")

    # === Environment Setup ===
    env, task = setup_environment()
    try:
        _descriptions, obs = task.reset()

        # (Optional) Initialize video writers for capturing your simulation
        init_video_writers(obs)

        # Wrap the task steps so every step/observation is recorded.
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # === Retrieve Object Positions ===
        positions = get_object_positions()
        object_keys = list(positions.keys())
        print("[Info] Retrieved positions for objects:", object_keys)

        # === Exploration Phase ===
        # Predicates we expect from the exploration domain.
        expected_predicates = {
            'identified',
            'temperature-known',
            'weight-known',
            'durability-known',
            'lock-known',
        }

        # Phases run in a fixed order: move, then pick/place, then pull.
        discovered = set()
        discovered |= _explore_move(env, task, object_keys)
        discovered |= _explore_pick_place(env, task, positions)
        discovered |= _explore_pull(env, task, positions, handle_key)

        missing = expected_predicates - discovered
        # Sorted so the log output is stable across runs (set order is not).
        print("[Exploration] Discovered predicates:", sorted(discovered))
        print("[Exploration] Missing predicates:", sorted(missing))

        # === (Optional) Oracle Plan Execution ===
        # If we had an oracle plan, we would execute it here step-by-step:
        # plan = [
        #     (rotate, {'target_quat': your_ninety_deg_quat}),
        #     (move,   {'from_loc': 'side_pos',   'to_loc': 'anchor_pos'}),
        #     (pick,   {'obj': 'drawer', 'loc': positions['anchor_pos']}),
        #     (pull,   {'obj': 'drawer', 'loc': positions['anchor_pos']}),
        # ]
        # for func, kwargs in plan:
        #     try:
        #         obs, reward, done = func(env, task, **kwargs)
        #         if done:
        #             print(f"[Plan] Terminated after {func.__name__}")
        #             break
        #     except Exception as e:
        #         print(f"[Plan] {func.__name__} failed with {e}")
        #         break

    finally:
        # Always ensure the environment is properly shutdown
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")

if '__main__' == __name__:
    # Launch the simulation only when executed as a script, never on import.
    run_skeleton_task()