import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

from skill_code import *

from video import init_video_writers, recording_step, recording_get_observation

from object_positions import get_object_positions

def _build_oracle_plan(pick_pos, place_pos, *, approach_distance=0.15,
                       max_steps=100, threshold=0.01, approach_axis='z',
                       timeout=10.0):
    '''Return the pick-and-place oracle plan as ``(skill_name, kwargs)`` pairs.

    Both steps share the same motion parameters, so they are defined once
    here instead of being duplicated per step.

    Args:
        pick_pos: Target position passed to the ``pick`` skill.
        place_pos: Target position passed to the ``place`` skill.
        approach_distance, max_steps, threshold, approach_axis, timeout:
            Forwarded unchanged as keyword arguments to each skill.

    Returns:
        A list of ``(name, params)`` tuples, executed in order.
    '''
    motion = dict(
        approach_distance=approach_distance,
        max_steps=max_steps,
        threshold=threshold,
        approach_axis=approach_axis,
        timeout=timeout,
    )
    # ``target_pos`` goes first so the logged params read as before.
    return [
        ('pick', dict(target_pos=pick_pos, **motion)),
        ('place', dict(target_pos=place_pos, **motion)),
    ]


def _resolve_skill(action_name):
    '''Look up the skill function called *action_name* in module globals.

    Skills are presumably provided by ``from skill_code import *`` at the top
    of this module.

    Raises:
        ValueError: if no callable with that name is in scope.
    '''
    skill_fn = globals().get(action_name)
    # A non-callable global with a matching name would otherwise only fail
    # later with a confusing TypeError at call time.
    if not callable(skill_fn):
        raise ValueError(f"Skill function '{action_name}' is not available.")
    return skill_fn


def run_skeleton_task():
    '''Generic skeleton for running any task in your simulation.

    Sets up the environment and resets the task. Wraps ``task.step`` and
    ``task.get_observation`` for video recording. Reads the positions of
    ``object_1`` and ``object_2``, then runs a two-step oracle plan: pick
    ``object_1`` and place it at ``object_2``'s location. The environment is
    always shut down, even if a step raises.

    Raises:
        ValueError: if a required object position is missing or a skill used
            by the plan is not available.
    '''
    print("===== Starting Skeleton Task =====")

    # === Environment Setup ===
    env, task = setup_environment()
    try:
        # Reset the task to its initial state (task descriptions are unused).
        _descriptions, obs = task.reset()

        # Initialize video writers for capturing your simulation
        init_video_writers(obs)

        # Wrap step/get_observation so every call goes through the recorder.
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # === Retrieve Object Positions ===
        positions = get_object_positions()
        print("[Info] Retrieved object positions:", positions.keys())

        # Fail fast with a clear message before any motion is attempted.
        for k in ('object_1', 'object_2'):
            if k not in positions:
                raise ValueError(f"Expected '{k}' in positions but not found.")

        # === Oracle Plan Definition ===
        # Pick up object_1 and place it at object_2's location.
        oracle_plan = _build_oracle_plan(positions['object_1'],
                                         positions['object_2'])

        # === Execute the Oracle Plan ===
        for step_idx, (action_name, params) in enumerate(oracle_plan, start=1):
            print(f"[Plan] Step {step_idx}: Calling skill '{action_name}' with params {params}")
            skill_fn = _resolve_skill(action_name)
            try:
                obs, reward, done = skill_fn(env, task, **params)
            except Exception as e:
                print(f"[Error] Exception during '{action_name}': {e}")
                raise
            print(f"[Result] '{action_name}' returned reward={reward}, done={done}")
            if done:
                print(f"[Task] Early termination after '{action_name}'")
                # ``break`` (not ``return``) so the end banner below is still
                # printed after the environment is shut down.
                break

    finally:
        # Always ensure the environment is properly shutdown
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")

if __name__ == "__main__":
    # Script entry point: run the task only when executed directly, never on import.
    run_skeleton_task()