import traceback

from video import init_video_writers, recording_step, recording_get_observation
from env import setup_environment, shutdown_environment
from skill_code import *
from object_positions import get_object_positions

def run_skeleton_task():
    """Run the two-tomatoes-onto-the-plate skeleton task end to end.

    Sets up the environment and resets the task. It then replaces
    ``task.step`` and ``task.get_observation`` with the wrappers returned by
    ``recording_step`` / ``recording_get_observation``, and looks up the
    required object positions. The plan it executes is:

    1. close the gripper,
    2. press the switch,
    3. pick ``tomato1`` and place it on the plate,
    4. pick ``tomato2`` and place it on the plate.

    Execution stops early as soon as any skill reports ``done``.

    Any ``Exception`` raised after the environment is set up is reported
    (message plus traceback) instead of being propagated. The environment is
    always shut down, and the closing banner is printed on every exit path,
    including early returns.

    Raises:
        Whatever ``setup_environment()`` raises: it runs before the guarded
        region, so its failures propagate to the caller.
    """
    print("===== Starting Skeleton Task =====")
    env, task = setup_environment()
    try:
        _, obs = task.reset()
        init_video_writers(obs)
        # Route stepping and observation through the video-recording wrappers.
        original_step = task.step
        task.step = recording_step(original_step)
        original_get_obs = task.get_observation
        task.get_observation = recording_get_observation(original_get_obs)

        # Retrieve object positions; fail fast and name every object whose
        # position is unavailable so the error is actionable.
        positions = get_object_positions()
        required = ('tomato1', 'tomato2', 'plate', 'switch')
        missing = [name for name in required if positions.get(name) is None]
        if missing:
            raise KeyError(
                "Missing required object positions: " + ", ".join(missing)
            )
        tomato1_pos = positions.get('tomato1')
        tomato2_pos = positions.get('tomato2')
        plate_pos = positions.get('plate')
        switch_pos = positions.get('switch')

        # bind item1_pos so frozen code picks tomato1 (the frozen oracle plan
        # below refers to the first item as item1_pos).
        item1_pos = tomato1_pos

        # Step 1: Close gripper
        print("[Plan] Step 1: Closing gripper")
        obs, reward, done = close_gripper(env, task)
        if done:
            print("[Plan] Task ended early after closing gripper")
            return

        # Step 2: Press switch
        print(f"[Plan] Step 2: Pressing switch at position {switch_pos}")
        obs, reward, done = press(
            env,
            task,
            target_pos=switch_pos,
            max_steps=100,
            threshold=0.005,
            timeout=10.0
        )
        if done:
            print("[Plan] Task ended early after pressing switch")
            return

        # Oracle Plan Execution (frozen)
        obs, reward, done = pick(
            env,
            task,
            target_pos=item1_pos,
            approach_distance=0.15,
            max_steps=100,
            threshold=0.01,
            approach_axis='z',
            timeout=10.0
        )
        if done:
            print("[Plan] Task ended early after pick tomato1")
            return

        obs, reward, done = place(
            env,
            task,
            target_pos=plate_pos,
            approach_distance=0.15,
            max_steps=100,
            threshold=0.01,
            approach_axis='z',
            timeout=10.0
        )
        if done:
            print("[Plan] Task ended early after place tomato1")
            return

        obs, reward, done = pick(
            env,
            task,
            target_pos=tomato2_pos,
            approach_distance=0.15,
            max_steps=100,
            threshold=0.01,
            approach_axis='z',
            timeout=10.0
        )
        if done:
            print("[Plan] Task ended early after pick tomato2")
            return

        obs, reward, done = place(
            env,
            task,
            target_pos=plate_pos,
            approach_distance=0.15,
            max_steps=100,
            threshold=0.01,
            approach_axis='z',
            timeout=10.0
        )
        if done:
            print("[Plan] Task ended early after place tomato2")
            return

        # final report
        print("[Plan] All steps executed. Dropped two tomatoes onto the plate.")
        print(f"[Plan] Final reward: {reward}, done flag: {done}")

    except Exception as e:
        # Top-level boundary: report instead of crashing, but keep the
        # traceback so failures inside the skills remain diagnosable.
        print(f"[Error] Exception during task execution: {e}")
        traceback.print_exc()
    finally:
        shutdown_environment(env)
        # Printed here so the banner also appears after the early returns
        # above, not only on the fall-through and exception paths.
        print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Run the task only when executed as a script, not when imported.
    run_skeleton_task()