import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

from skill_code import *

from video import init_video_writers, recording_step, recording_get_observation

from object_positions import get_object_positions

# Motion limits shared by every move/pick/place call in this task.
_MOTION_LIMITS = {"max_steps": 100, "threshold": 0.01, "timeout": 5.0}


def _run_skill(skill, env, task, end_message, **skill_kwargs):
    """Invoke one skill primitive and report whether the episode ended.

    ``skill`` is called as ``skill(env, task, **skill_kwargs)``. Its result is
    unpacked as a 3-tuple ``(obs, _, done)``; the middle item is ignored.

    Args:
        skill: Skill callable (e.g. ``move``, ``pick``, ``place``, ``pull``,
            ``rotate``).
        env, task: Environment and task handles forwarded to the skill.
        end_message: Printed when the skill reports ``done``.
        **skill_kwargs: Extra keyword arguments forwarded to the skill.

    Returns:
        ``(obs, done)``, i.e. the skill's observation and its ``done`` flag.
    """
    obs, _, done = skill(env, task, **skill_kwargs)
    if done:
        print(end_message)
    return obs, done


def _explore_objects(env, task, object_positions, knowledge):
    """Move to, pick up, and re-place each object, recording what the pick revealed.

    For each object, one entry is stored in ``knowledge``:
    ``'weight-known'``, ``'durability-known'`` or ``'unknown'``. The value
    depends on flags read from the observation returned by ``pick``.

    Returns:
        True if the task ended during this phase, otherwise False.
    """
    for obj_name, obj_pos in object_positions.items():
        print(f"[Exploration] Approaching {obj_name} at {obj_pos}")
        _, done = _run_skill(move, env, task, "[Task] Ended during move to object.",
                             target_pos=obj_pos, approach_distance=0.15,
                             **_MOTION_LIMITS)
        if done:
            return True

        print(f"[Exploration] Picking up {obj_name} to identify properties")
        obs, done = _run_skill(pick, env, task, "[Task] Ended during pick exploration.",
                               target_pos=obj_pos, approach_distance=0.10,
                               **_MOTION_LIMITS)
        if done:
            return True

        # Flags missing from the observation are treated as False.
        if getattr(obs, 'weight_known', False):
            knowledge[obj_name] = 'weight-known'
        elif getattr(obs, 'durability_known', False):
            knowledge[obj_name] = 'durability-known'
        else:
            knowledge[obj_name] = 'unknown'
        print(f"[Exploration] Learned for {obj_name}: {knowledge[obj_name]}")

        # Put the object back at the position it was picked from.
        _, done = _run_skill(place, env, task, "[Task] Ended during place exploration.",
                             target_pos=obj_pos, approach_distance=0.10,
                             **_MOTION_LIMITS)
        if done:
            return True
    return False


def _explore_drawer_handles(env, task, positions, drawer_handles, knowledge):
    """Move to and pull each drawer handle, recording whether the lock state was reported.

    For each handle, ``'lock-known'`` or ``'lock-unknown'`` is stored in
    ``knowledge``. The value depends on the ``lock_known`` flag of the
    observation returned by ``pull``.

    Returns:
        True if the task ended during this phase, otherwise False.
    """
    for handle_name in drawer_handles:
        handle_pos = positions[handle_name]
        print(f"[Exploration] Approaching drawer handle {handle_name} at {handle_pos}")
        _, done = _run_skill(move, env, task, "[Task] Ended during move to drawer handle.",
                             target_pos=handle_pos, approach_distance=0.15,
                             **_MOTION_LIMITS)
        if done:
            return True

        print(f"[Exploration] Pulling on {handle_name} to learn lock state")
        obs, done = _run_skill(pull, env, task, "[Task] Ended during pull exploration.")
        if done:
            return True

        if getattr(obs, 'lock_known', False):
            knowledge[handle_name] = 'lock-known'
        else:
            knowledge[handle_name] = 'lock-unknown'
        print(f"[Exploration] Learned for {handle_name}: {knowledge[handle_name]}")
    return False


def _open_unlocked_drawer(env, task, positions, drawer_handles, knowledge):
    """Rotate, approach and pull the first handle classified as ``'lock-unknown'``.

    NOTE(review): this treats "exploration did not report ``lock_known``" as
    "the drawer is unlocked". Confirm that this matches the intended meaning
    of the exploration predicates.

    Prints a message when no handle qualifies.

    Returns:
        True if the task ended during this phase, otherwise False.
    """
    for handle_name in drawer_handles:
        if knowledge.get(handle_name) != 'lock-unknown':
            continue

        drawer_handle_pos = positions[handle_name]
        print(f"[Task] Opening drawer via handle {handle_name}")

        # TODO: placeholder for the "ninety_deg" orientation. If rotate expects
        # (x, y, z, w) order, [0, 0, 0, 1] is the identity rotation.
        target_quat = [0, 0, 0, 1]
        _, done = _run_skill(rotate, env, task, "[Task] Ended during rotate to open drawer.",
                             target_quat=target_quat, max_steps=100,
                             threshold=0.05, timeout=10.0)
        if done:
            return True

        _, done = _run_skill(move, env, task, "[Task] Ended during approach to handle.",
                             target_pos=drawer_handle_pos, approach_distance=0.15,
                             **_MOTION_LIMITS)
        if done:
            return True

        _, done = _run_skill(pull, env, task, "[Task] Ended during final pull.")
        if done:
            return True

        print("[Task] Drawer opened successfully.")
        return False

    print("[Task] No drawer handle classified as 'lock-unknown'; no drawer opened.")
    return False


def run_skeleton_task():
    """Run the exploration-then-open-drawer task end to end.

    Phases:
      1. Reset the task and wrap ``task.step`` / ``task.get_observation`` with
         the video-recording wrappers.
      2. Explore every object (move, pick, place back) and every drawer handle
         (move, pull), recording what each interaction revealed.
      3. Open the first drawer whose handle was classified ``'lock-unknown'``.

    The function returns early, without the final banner, as soon as any
    skill reports ``done``. The environment is always shut down via
    ``finally``.
    """
    print("===== Starting Skeleton Task =====")
    env, task = setup_environment()
    try:
        _descriptions, obs = task.reset()
        init_video_writers(obs)
        # Route every step/observation call through the recording wrappers.
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # Split scene positions into graspable objects and drawer handles by name.
        positions = get_object_positions()
        object_positions = {name: pos for name, pos in positions.items() if 'object' in name}
        drawer_handles = [name for name in positions if 'drawer' in name and 'handle' in name]

        # Maps object/handle name -> learned-property label.
        exploration_knowledge = {}

        if _explore_objects(env, task, object_positions, exploration_knowledge):
            return
        if _explore_drawer_handles(env, task, positions, drawer_handles,
                                   exploration_knowledge):
            return

        print("[Exploration] Summary of learned predicates:", exploration_knowledge)

        if _open_unlocked_drawer(env, task, positions, drawer_handles,
                                 exploration_knowledge):
            return
    finally:
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Script entry point: run the task once when this file is executed directly.
    run_skeleton_task()