from env import setup_environment, shutdown_environment
from skill_code import *
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions
import numpy as np

def _require_position(positions, key):
    """Return ``positions[key]`` converted to a NumPy array.

    Args:
        positions: Mapping returned by ``get_object_positions()``.
        key: Name of the entry that must be present.

    Raises:
        KeyError: If ``key`` is not present in ``positions``.
    """
    if key not in positions:
        raise KeyError(f"{key} not found in positions")
    return np.array(positions[key])


def run_skeleton_task():
    """Execute the scripted plan: open the bottom drawer, then plate the tomatoes.

    The plan rotates the gripper to the ``ninety_deg`` orientation, moves to
    the drawer's side and anchor positions, pulls the drawer, and then picks
    ``tomato1`` and ``tomato2`` in turn, placing each at the ``plate``
    position. The function returns early as soon as a checked skill call
    reports ``done``.

    Raises:
        KeyError: If a required entry (``ninety_deg``, the drawer side/anchor
            positions, ``plate`` or a tomato) is missing from the positions
            returned by ``get_object_positions()``.
    """
    print("===== Starting Skeleton Task =====")

    # Environment setup
    env, task = setup_environment()
    try:
        # Reset the task to its initial state
        descriptions, obs = task.reset()

        # (Optional) Initialize video writers for recording
        init_video_writers(obs)

        # Wrap the task methods so every step/observation is also recorded
        original_step = task.step
        task.step = recording_step(original_step)
        original_get_obs = task.get_observation
        task.get_observation = recording_get_observation(original_get_obs)

        # Retrieve object positions and orientations
        positions = get_object_positions()

        # Step 1: Rotate gripper to ninety_deg
        print("[Plan] Step 1: rotate gripper to ninety_deg")
        # .get() is used (rather than a membership test) so that an entry
        # explicitly stored as None is rejected as well.
        ninety_quat = positions.get('ninety_deg')
        if ninety_quat is None:
            raise KeyError("ninety_deg orientation not found in positions")

        # Choose drawer to open
        drawer = 'bottom'

        # Step 2: Move to side position of the drawer
        side_key = f"{drawer}_side_pos"
        print(f"[Plan] Step 2: move to side position '{side_key}'")
        side_pos = _require_position(positions, side_key)

        # Step 3: Move to anchor position of the drawer
        anchor_key = f"{drawer}_anchor_pos"
        print(f"[Plan] Step 3: move to anchor position '{anchor_key}'")
        anchor_pos = _require_position(positions, anchor_key)

        # ---- Frozen Code Start ----
        # NOTE: the rotate() call previously lacked its closing parenthesis,
        # which made the whole file a SyntaxError; only that was repaired.
        obs, reward, done = rotate(env, task, target_quat=np.array(ninety_quat))
        obs, reward, done = move(env, task, target_pos=side_pos)
        obs, reward, done = move(env, task, target_pos=anchor_pos)
        # ---- Frozen Code End ----

        # Only the `done` flag of the last approach call is inspected here.
        if done:
            print("Task ended during initial drawer approach.")
            return

        # Step 4: Pull the drawer open
        print("[Plan] Step 4: pull drawer open")
        obs, reward, done = pull(env, task, pull_distance=0.1, pull_axis='x')
        if done:
            print("Task ended during pull.")
            return

        # Steps 5-8: Pick each tomato and place on plate
        plate_pos = _require_position(positions, 'plate')

        step_idx = 5
        for tomato in ['tomato1', 'tomato2']:
            # Pick step
            print(f"[Plan] Step {step_idx}: pick '{tomato}' from table")
            tomato_pos = _require_position(positions, tomato)
            obs, reward, done = pick(env, task, target_pos=tomato_pos)
            if done:
                print(f"Task ended during pick of {tomato}.")
                return

            # Place step
            print(f"[Plan] Step {step_idx+1}: place '{tomato}' on plate")
            obs, reward, done = place(env, task, target_pos=plate_pos)
            if done:
                print(f"Task ended during place of {tomato}.")
                return

            # Each tomato consumes two plan steps (pick + place).
            step_idx += 2

        # Final status
        print("[Task] All steps executed. Final reward:", reward)
        print("[Task] Done. Goal achieved.")

    finally:
        # Ensure environment is properly shut down, even on error/early return
        shutdown_environment(env)

# Run the task only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    run_skeleton_task()