# ruff: noqa

import contextlib
import dataclasses
import datetime
import faulthandler
import json
import logging
import os
import signal
import sys
import select
import termios
import tty
import time
import threading
import copy
from collections import deque
import concurrent.futures
import cv2
import math
import numpy as np
import torch
from openpi_client import image_tools
from openpi_client import websocket_client_policy
import pandas as pd
from PIL import Image
from torchcubicspline import NaturalCubicSpline, left_clamped_cubic_spline_coeffs
from droid.panda_env import TOPPRAEnv
import tqdm
import tyro
from furniture_bench.async_utils.toppra_server import ToppraServer

from typing import Optional


faulthandler.enable()

# Root logging format: wall-clock time with milliseconds, logger name, level, message.
logging.basicConfig(
    format="%(asctime)s.%(msecs)03d [%(name)s] [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger('inference_toppra')

# Per-component verbosity; only the interpolator logger is set to DEBUG.
for _logger_name, _logger_level in (
    ('inference_toppra', logging.INFO),
    ('toppra_server', logging.INFO),
    ('toppra_interpolator', logging.DEBUG),
    ('panda', logging.INFO),
):
    logging.getLogger(_logger_name).setLevel(_logger_level)

# Control frequency in Hz. The DROID data collection frequency is 15; 30 runs the rollout
# "fast forward" relative to that.
DROID_CONTROL_FREQUENCY = 30
# Polling rate in Hz of the observation thread; also used as the fps of the recorded video.
OBS_FREQUENCY = 60


@dataclasses.dataclass
class Args:
    """Configuration for one inference session (the file imports tyro, presumably to parse it).

    NOTE: the defaults of ``vel_gain_scale`` and ``name`` are evaluated once, while this class
    body executes, from the *default* values of the fields they mention. Overriding for example
    ``gain_scale`` or ``num_samples`` on an instance does not recompute them; set them explicitly
    as well if they should stay in sync.
    """

    # Hardware parameters
    left_camera_id: str = "20655732"  # e.g., "24259877"
    right_camera_id: str = "20655732"  # e.g., "24514023"
    wrist_camera_id: str = "16787047"  # e.g., "13062452"

    # Policy parameters
    # Which external camera is fed to the policy; main() asserts it is "left" or "right".
    external_camera: str = "left"
    # Forwarded to the environment constructor.
    action_space: str = "joint_position"

    # When True, main() builds the executed action from policy output columns 8+ followed by
    # column 7 instead of taking the first 8 columns.
    state_as_action: bool = False

    # Forwarded to the environment constructor. The annotation is required: without it this is a
    # plain class attribute rather than a dataclass field, so it would be missing from
    # dataclasses.asdict() (and hence from the saved args.json) and could not be overridden
    # through the constructor.
    toppra_last_vel: float = 1.0

    # Both forwarded to the environment constructor.
    gain_scale: float = 1.4
    vel_gain_scale: float = math.sqrt(gain_scale)  # frozen at sqrt(1.4); see class docstring

    # Rollout parameters
    # Instruction used when the operator just presses Enter.
    # Alternatives: "Put the snacks in the box", "Put the trashes in the trash bin"
    default_prompt: str = "Put the snack bags in the basket"
    max_duration: float = 120.0  # seconds of wall-clock time before the rollout loop stops
    max_timesteps: int = 1000000  # upper bound on rollout loop iterations

    # How many actions to execute from a predicted action chunk before querying policy server again
    # 8 is usually a good default (equals roughly 0.5 seconds of action execution at 15 Hz).
    open_loop_horizon: int = 8

    # Number of action chunks requested from the policy server per query (best-of-N selection).
    num_samples: int = 16
    # Number of leading actions of a chunk, and number of spline sample points, used for scoring.
    spline_curvature_length: int = 8
    spline_curvature_num_points: int = 100

    # Minimum inference latency in seconds, enforced by sleeping. 0.2 for ablation.
    artificial_delay: float = 0.0

    # Run name; becomes part of the results directory path. Built from the defaults above at
    # class-definition time (see class docstring).
    name: str = f"ours_toppra_last_vel{toppra_last_vel}_gain_{gain_scale}_torque0.99_vel0.7_bon{num_samples}_curve_len{spline_curvature_length}_{DROID_CONTROL_FREQUENCY}hz_30000epoch"
    # Ablation variant of the run name:
    # name: str = f"ours_ablation_delay{artificial_delay}_toppra_last_vel{toppra_last_vel}_gain_{gain_scale}_torque0.99_vel0.7_bon{num_samples}_curve_len{spline_curvature_length}_{DROID_CONTROL_FREQUENCY}hz_30000epoch"

    # Remote server parameters
    remote_host: str = "192.168.0.136"  # point this to the IP address of the policy server, e.g., "192.168.1.100"
    remote_port: int = (
        8000  # point this to the port of the policy server, default server port for openpi servers is 8000
    )


# We are using Ctrl+C to optionally terminate rollouts early -- however, if we press Ctrl+C while the client is
# waiting on the policy server for a new action chunk, it will raise an exception and the server connection dies.
# This context manager temporarily holds back Ctrl+C and re-raises it once the server call is complete.
@contextlib.contextmanager
def prevent_keyboard_interrupt():
    """Hold back Ctrl+C while the managed block runs and deliver it afterwards.

    A temporary SIGINT handler only records that the signal arrived. When the block exits
    (normally or through an exception) the previous handler is reinstalled, and if at least
    one SIGINT was recorded a ``KeyboardInterrupt`` is raised at that point.

    Raises:
        KeyboardInterrupt: On exit, if SIGINT was received inside the block.
    """
    pending = []

    # signal.signal() hands back the handler it replaces, which is what we restore later.
    previous_handler = signal.signal(
        signal.SIGINT, lambda signum, frame: pending.append(signum)
    )
    try:
        yield
    finally:
        signal.signal(signal.SIGINT, previous_handler)
        if pending:
            raise KeyboardInterrupt


def spline_curvature(acs, joint_pos, joint_vel, spline_curvature_length, spline_curvature_num_points):
    """Score action chunks by smoothness; a score closer to zero means a smoother path.

    A cubic spline over the knots ``linspace(0, 1, horizon + 1)`` is fitted through the current
    joint position followed by every action with its last component removed, with ``joint_vel``
    handed to ``left_clamped_cubic_spline_coeffs`` as the left-end condition. The second
    derivative is sampled at ``spline_curvature_num_points`` evenly spaced parameters in
    ``[0, spline_curvature_length / horizon]``; the score is the negated mean of its squared norm.

    Args:
        acs: Actions as a NumPy array or tensor, ``(horizon, action_dim)`` or
            ``(batch, horizon, action_dim)``. The last component of every action is dropped;
            callers in this file keep the gripper command there -- presumably the reason, confirm.
        joint_pos: Current joint positions, NumPy array or tensor, with or without a leading
            batch axis of size 1. Its device is the one everything is moved to.
        joint_vel: Current joint velocities, same conventions as ``joint_pos``.
        spline_curvature_length: Number of leading actions whose span is scored.
        spline_curvature_num_points: Number of points at which the second derivative is sampled.

    Returns:
        Tensor of scores, one per batch entry -- assumes ``spline.derivative`` returns
        ``(batch, points, dof)``; TODO confirm against torchcubicspline.
    """

    def _tensor(value):
        # NumPy inputs are converted to float32; tensors are passed through untouched.
        if isinstance(value, torch.Tensor):
            return value
        return torch.from_numpy(value).float()

    joint_pos = _tensor(joint_pos)
    device = joint_pos.device
    acs = _tensor(acs).to(device)
    joint_vel = _tensor(joint_vel).to(device)

    # Promote un-batched inputs to a batch of one.
    if acs.ndim == 2:
        acs = acs.unsqueeze(0)
    if joint_pos.ndim == 1:
        joint_pos = joint_pos.unsqueeze(0)
    if joint_vel.ndim == 1:
        joint_vel = joint_vel.unsqueeze(0)

    batch_size = acs.shape[0]
    horizon = acs.shape[1]

    knots = torch.linspace(0, 1, horizon + 1, device=device).float()
    start_point = joint_pos.unsqueeze(1).expand(batch_size, -1, -1)
    waypoints = torch.cat([start_point, acs[:, :, :-1]], dim=1)
    start_velocity = joint_vel.expand(batch_size, -1)
    spline = NaturalCubicSpline(left_clamped_cubic_spline_coeffs(knots, waypoints, start_velocity))

    query = torch.linspace(
        0, spline_curvature_length / horizon, spline_curvature_num_points, device=device
    ).float()
    second_derivative = spline.derivative(query, order=2)

    squared_norm = second_derivative.norm(dim=2) ** 2
    return -(squared_norm.mean(dim=1))


def _select_best_action_chunk(
    fresh_actions_from_policy,
    env,
    num_samples,
    spline_curvature_length,
    spline_curvature_num_points,
):
    """Choose the smoothest candidate action chunk according to ``spline_curvature``.

    The multi-sample path is taken when ``num_samples > 1`` and the input is 3-D. Otherwise the
    input is scored and returned unchanged.

    Args:
        fresh_actions_from_policy: NumPy array of actions. Either a single chunk
            ``(horizon, action_dim)`` or several candidates, ``(num_samples, horizon, action_dim)``
            or ``(horizon, num_samples, action_dim)``. A 3-D array whose first axis equals
            ``num_samples`` is treated as samples-first and transposed.
        env: Environment; ``env.robot.arm`` supplies the current joint positions and velocities.
        num_samples: Number of candidates requested from the policy.
        spline_curvature_length: Maximum number of leading actions used for scoring.
        spline_curvature_num_points: Number of spline sample points used for scoring.

    Returns:
        ``(actions, scores)``. In the multi-sample path ``actions`` is the highest-scoring
        candidate, ``(horizon, action_dim)``, and ``scores`` holds one score per candidate.
        In the single-sample path ``actions`` is the input itself.
    """
    is_multi_sample = num_samples > 1 and fresh_actions_from_policy.ndim == 3
    logger.info(f"fresh_actions_from_policy.ndim: {fresh_actions_from_policy.ndim}")

    q0 = env.robot.arm.get_joint_positions()
    joint_vel = env.robot.arm.get_joint_velocities()

    if not is_multi_sample:
        curve_len = min(spline_curvature_length, len(fresh_actions_from_policy))
        scores = spline_curvature(
            fresh_actions_from_policy[:curve_len], q0, joint_vel, curve_len, spline_curvature_num_points
        )
        return fresh_actions_from_policy, scores

    # OpenPI policy server returns (num_samples, horizon, action_dim), but we expect (horizon, num_samples, action_dim)
    if fresh_actions_from_policy.shape[0] == num_samples:
        fresh_actions_from_policy = fresh_actions_from_policy.transpose(1, 0, 2)

    horizon, num_samples = fresh_actions_from_policy.shape[:2]

    # Measure the scoring window on the horizon axis, after the layout has been normalised.
    # len() of the raw samples-first array is the sample count, which shortened the window when
    # num_samples < spline_curvature_length and let it exceed the available timesteps (sampling
    # the spline past its last knot) when the chunk was shorter than spline_curvature_length.
    curve_len = min(spline_curvature_length, horizon)

    # (num_samples, curve_len, action_dim), copied into a fresh C-contiguous array.
    candidates = np.array(fresh_actions_from_policy[:curve_len].transpose(1, 0, 2), order="C")
    scores = spline_curvature(
        torch.from_numpy(candidates), q0, joint_vel, curve_len, spline_curvature_num_points
    )
    # Scores are negated curvature, so the maximum is the smoothest candidate.
    best_idx = torch.argmax(scores).item()

    # Tensor.tolist() works regardless of device, unlike .numpy().
    logger.info(f"Best of {num_samples} scores: {scores.tolist()}, selected index: {best_idx}")

    best_actions = fresh_actions_from_policy[:, best_idx, :]
    return best_actions, scores


def _obs_fetcher_thread(env, args, shared_data, save_dir):
    """Poll the environment for observations and record the external camera to a video file.

    Meant to be used as a thread target. Until ``shared_data["stop"]`` becomes truthy, each cycle
    reads an observation, appends the ``args.external_camera`` image (when present) to the video
    at ``shared_data["video_filename"]`` and publishes the observation as
    ``shared_data["latest_observation"]``. Cycles are paced to at most ``OBS_FREQUENCY`` per
    second. The video writer is released when the loop ends, including on an exception.

    Args:
        env: Environment exposing ``get_observation()``.
        args: Run configuration; ``external_camera`` selects which image is recorded.
        shared_data: Dict shared with the main thread. ``"frames_written"`` and
            ``"latest_observation"`` are updated while holding ``shared_data["lock"]``.
        save_dir: Forwarded to ``_extract_observation``, which is asked to save to disk on the
            first cycle only.
    """
    first_cycle = True
    writer = None
    try:
        while True:
            if shared_data["stop"]:
                break
            cycle_start = time.time()

            observation = _extract_observation(
                args,
                env.get_observation(),
                save_to_disk=first_cycle,
                save_dir=save_dir,
            )
            frame = observation[f"{args.external_camera}_image"]

            if frame is not None:
                if writer is None:
                    # The frame size is only known once the first image has arrived.
                    rows, cols, _ = frame.shape
                    writer = cv2.VideoWriter(
                        shared_data["video_filename"],
                        cv2.VideoWriter_fourcc(*"mp4v"),
                        OBS_FREQUENCY,
                        (cols, rows),
                    )
                writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                with shared_data["lock"]:
                    shared_data["frames_written"] += 1
            first_cycle = False

            with shared_data["lock"]:
                shared_data["latest_observation"] = observation

            # Sleep away whatever is left of this cycle's time budget.
            spare = 1 / OBS_FREQUENCY - (time.time() - cycle_start)
            if spare > 0:
                time.sleep(spare)
    finally:
        if writer:
            writer.release()


def main(args: Args):
    # Make sure external camera is specified by user -- we only use one external camera for the policy
    assert (
        args.external_camera is not None and args.external_camera in ["left", "right"]
    ), f"Please specify an external camera to use for the policy, choose from ['left', 'right'], but got {args.external_camera}"

    # Initialize the Panda environment. Using joint velocity action space and gripper position action space is very important.
    logger.info("Creating the droid env...")
    camera_kwargs = {
        "left_camera_id": args.left_camera_id,
        "right_camera_id": args.right_camera_id,
        "wrist_camera_id": args.wrist_camera_id,
    }
    env = TOPPRAEnv(action_space=args.action_space, control_hz=DROID_CONTROL_FREQUENCY, toppra_last_vel=args.toppra_last_vel, gain_scale=args.gain_scale, vel_gain_scale=args.vel_gain_scale, camera_kwargs=camera_kwargs)
    logger.info("Created the droid env!")

    # Connect to the policy server
    policy_client = websocket_client_policy.WebsocketClientPolicy(args.remote_host, args.remote_port)
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

    results = []
    warmed_up = False

    while True:
        instruction = input(f"Enter instruction (press Enter for default: '{args.default_prompt}'): ")
        if not instruction:
            instruction = args.default_prompt
            logger.info(f"Using default instruction: '{instruction}'")
        
        name_input = input(f"Enter name for this rollout (press Enter for default: '{args.name}'): ")
        if name_input:
            args.name = name_input
        
        if not warmed_up:
            # Warmup phase
            logger.info("Starting warmup phase...")
            env.reset()
            obs = env.get_observation()
            curr_obs = _extract_observation(args, obs)
            request_data = {
                "obs": {
                    "observation/exterior_image_1_left": image_tools.resize_with_pad(
                        curr_obs[f"{args.external_camera}_image"], 224, 224
                    ),
                    "observation/wrist_image_left": image_tools.resize_with_pad(
                        curr_obs["wrist_image"], 224, 224
                    ),
                    "observation/joint_position": curr_obs["joint_position"],
                    "observation/gripper_position": curr_obs["gripper_position"],
                    "prompt": instruction,
                },
                "sample_kwargs": {
                    "num_samples": args.num_samples,
                },
            }
            for _ in tqdm.tqdm(range(5), desc="Policy server warmup"):
                policy_client.infer(request_data)
            logger.info("Warmup complete.")
            warmed_up = True

        sane_instruction = "".join(c for c in instruction if c.isalnum() or c in (" ", "_")).rstrip().replace(" ", "_")

        # Reset env to initialize interpolator and robot state
        # env.reset()

        # Start Toppra Server
        toppra_server = ToppraServer(
            interpolator=env.robot.interpolator,
            port=8767,
            inv_dyn=env.robot.get_inv_dyn(),
            toppra_last_vel=env.robot.toppra_last_vel,
        )
        toppra_server_thread = threading.Thread(target=toppra_server.serve_forever, name="ToppraServerThread", daemon=True)
        toppra_server_thread.start()

        start_time = time.time()
        while not toppra_server._is_running:
            time.sleep(0.1)
            if time.time() - start_time > 5:  # 5-second timeout
                raise RuntimeError("ToppraServer failed to start within the timeout period.")
        
        # Start robot's async services (ToppraClient)
        env.robot.start_async_services()
        timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H:%M:%S")
        
        save_dir = os.path.join("results", sane_instruction, args.name, timestamp)
        os.makedirs(save_dir, exist_ok=True)
        
        args_dict = dataclasses.asdict(args)
        args_dict["DROID_CONTROL_FREQUENCY"] = DROID_CONTROL_FREQUENCY
        with open(os.path.join(save_dir, "args.json"), "w") as f:
            json.dump(args_dict, f, indent=4)
        
        video_filename = os.path.join(save_dir, "video.mp4")

        # This will be shared between main thread and obs fetcher thread
        shared_data = {
            "latest_observation": None,
            "video_filename": video_filename,
            "frames_written": 0,
            "stop": False,
            "lock": threading.Lock(),
        }

        # Start obs fetcher thread
        obs_thread = threading.Thread(target=_obs_fetcher_thread, args=(env, args, shared_data, save_dir))
        obs_thread.start()

        # Wait for the first observation to be available
        while True:
            with shared_data["lock"]:
                if shared_data["latest_observation"] is not None:
                    break
            time.sleep(0.01)

        # ASYNC INFERENCE STATE
        pred_action_chunk = None
        inference_future = None
        inference_start_time: Optional[float] = None
        inference_delays = deque([0.1], maxlen=10)  # Start with a reasonable guess
        inference_start_waypoint = -1
        discard_steps_history = deque([1], maxlen=10)
        current_plan_total_waypoints = 0
        current_plan_start_waypoint = 0
        steps_during_inference = 0
        score_timestamps = []
        spline_curvature_list = []
        normalized_action_consistency_list = []
        current_plan_for_consistency = None
        prev_plan_for_consistency_check = None
        remaining_waypoints_at_execution_start = 0

        # Prepare to save video of rollout
        bar = tqdm.tqdm(range(args.max_timesteps))
        logger.info("Running rollout... press Ctrl+C to stop early, or 'r' to finish gracefully.")
        t_step = 0

        old_settings = termios.tcgetattr(sys.stdin)
        try:
            tty.setcbreak(sys.stdin.fileno())
            rollout_start_time = time.time()
            try:
                for t_step in bar:
                    if time.time() - rollout_start_time > args.max_duration:
                        logger.info(f"Max duration of {args.max_duration} seconds reached. Finishing rollout.")
                        break
                    if select.select([sys.stdin], [], [], 0) == ([sys.stdin], [], []):
                        key = sys.stdin.read(1)
                        if key == "r":
                            break
                        elif key == "t":
                            score_time = time.time() - rollout_start_time
                            score_timestamps.append(score_time)
                            logger.info(f"Score +1 at {score_time:.2f}s. Total score: {len(score_timestamps)}")
                        elif key == "u":
                            if score_timestamps:
                                score_timestamps.pop()
                                logger.info(f"Undo score. Total score: {len(score_timestamps)}")
                            else:
                                logger.info("No score to undo.")
                    start_time = time.time()

                    # 1. CHECK FOR COMPLETED INFERENCE
                    if inference_future is not None and inference_future.done():
                        result = inference_future.result()
                        if inference_start_time is not None:
                            inference_delay = time.time() - inference_start_time
                            if inference_delay < args.artificial_delay:
                                time.sleep(args.artificial_delay - inference_delay)
                                inference_delay = args.artificial_delay
                            inference_delays.append(inference_delay)
                            logger.info(f"Inference delay: {inference_delay:.4f}s. Max delay: {max(inference_delays):.4f}s")
                            inference_start_time = None
                        print(result["actions"].shape)
                        new_action_chunk = result["actions"]

                        original_step_index = result.get("step_index", inference_start_waypoint)
                        if original_step_index != -1:
                            current_waypoints = env.get_total_completed_waypoints()
                            steps_during_inference = current_waypoints - original_step_index
                        else:
                            steps_during_inference = 0  # Fallback
                        
                        if prev_plan_for_consistency_check is not None:
                            best_full_new_chunk, _ = _select_best_action_chunk(
                                new_action_chunk.copy(),
                                env,
                                args.num_samples,
                                args.spline_curvature_length,
                                args.spline_curvature_num_points,
                            )
                            if best_full_new_chunk is not None:
                                len_to_compare = steps_during_inference
                                if len_to_compare > 0 and len_to_compare <= best_full_new_chunk.shape[0] and len_to_compare <= len(prev_plan_for_consistency_check):
                                    consistency = np.linalg.norm(best_full_new_chunk[:len_to_compare] - prev_plan_for_consistency_check[:len_to_compare]) / (len_to_compare * best_full_new_chunk.shape[1])
                                    normalized_action_consistency_list.append(consistency)
                        prev_plan_for_consistency_check = None

                        discard_steps_history.append(steps_during_inference)

                        if steps_during_inference < new_action_chunk.shape[1]:
                            # Replace the current action chunk with the new, fresher one
                            fresh_actions = new_action_chunk[..., steps_during_inference:, :]
                            if args.state_as_action:
                                fresh_actions = np.concatenate([fresh_actions[..., 8:], fresh_actions[..., 7:8]], axis=-1)
                            else:
                                fresh_actions = fresh_actions[..., :8]
                            pred_action_chunk, scores = _select_best_action_chunk(
                                fresh_actions,
                                env,
                                args.num_samples,
                                args.spline_curvature_length,
                                args.spline_curvature_num_points,
                            )
                            if scores is not None:
                                spline_curvature_list.append(scores.max().item())
                            logger.info(f"New action chunk received. Discarded {steps_during_inference} steps.")
                        else:
                            logger.info(f"Stale action chunk received. Discarded all {len(new_action_chunk)} steps.")
                            pred_action_chunk = None

                        inference_future = None
                        
                    # 2. ENSURE WE HAVE AN ACTION TO EXECUTE
                    if pred_action_chunk is None:
                        if inference_future is not None:
                            # Block and wait for the ongoing inference to complete
                            logger.info("Waiting for ongoing inference to complete...")
                            result = inference_future.result()
                            if inference_start_time is not None:
                                inference_delay = time.time() - inference_start_time
                                if inference_delay < args.artificial_delay:
                                    time.sleep(args.artificial_delay - inference_delay)
                                    inference_delay = args.artificial_delay
                                inference_delays.append(inference_delay)
                                logger.info(f"Inference delay: {inference_delay:.4f}s. Max delay: {max(inference_delays):.4f}s")
                                inference_start_time = None
                            print(result["actions"].shape)
                            new_action_chunk = result["actions"]
                            
                            original_step_index = result.get("step_index", inference_start_waypoint)
                            if original_step_index != -1:
                                current_waypoints = env.get_total_completed_waypoints()
                                steps_during_inference = current_waypoints - original_step_index
                            else:
                                steps_during_inference = 0

                            if prev_plan_for_consistency_check is not None:
                                best_full_new_chunk, _ = _select_best_action_chunk(
                                    new_action_chunk.copy(),
                                    env,
                                    args.num_samples,
                                    args.spline_curvature_length,
                                    args.spline_curvature_num_points,
                                )
                                if best_full_new_chunk is not None:
                                    len_to_compare = steps_during_inference
                                    if len_to_compare > 0 and len_to_compare <= best_full_new_chunk.shape[0] and len_to_compare <= len(prev_plan_for_consistency_check):
                                        consistency = np.linalg.norm(best_full_new_chunk[:len_to_compare] - prev_plan_for_consistency_check[:len_to_compare]) / (len_to_compare * best_full_new_chunk.shape[1])
                                        normalized_action_consistency_list.append(consistency)
                            prev_plan_for_consistency_check = None

                            if steps_during_inference < new_action_chunk.shape[1]:
                                fresh_actions = new_action_chunk[..., steps_during_inference:, :]
                                if args.state_as_action:
                                    fresh_actions = np.concatenate(
                                        [fresh_actions[..., 8:], fresh_actions[..., 7:8]], axis=-1
                                    )
                                else:
                                    fresh_actions = fresh_actions[..., :8]
                                pred_action_chunk, scores = _select_best_action_chunk(
                                    fresh_actions,
                                    env,
                                    args.num_samples,
                                    args.spline_curvature_length,
                                    args.spline_curvature_num_points,
                                )
                                if scores is not None:
                                    spline_curvature_list.append(scores.max().item())
                                logger.info(f"New action chunk received. Discarded {steps_during_inference} steps.")
                            else:
                                logger.info(f"Stale action chunk received. Discarded all {len(new_action_chunk)} steps.")
                                pred_action_chunk = None  # It's all stale
                            
                            inference_future = None
                        else:
                            # Blocking inference call for the first step or if a chunk becomes fully stale.
                            logger.info("No actions available, performing blocking inference.")
                            with shared_data["lock"]:
                                curr_obs = copy.deepcopy(shared_data["latest_observation"])
                            if curr_obs is None:
                                time.sleep(0.01)
                                continue

                            request_data = {
                                "obs": {
                                    "observation/exterior_image_1_left": image_tools.resize_with_pad(
                                        curr_obs[f"{args.external_camera}_image"], 224, 224
                                    ),
                                    "observation/wrist_image_left": image_tools.resize_with_pad(
                                        curr_obs["wrist_image"], 224, 224
                                    ),
                                    "observation/joint_position": curr_obs["joint_position"],
                                    "observation/gripper_position": curr_obs["gripper_position"],
                                    "prompt": instruction,
                                },
                                "sample_kwargs": {
                                    "num_samples": args.num_samples,
                                },
                            }
                            with prevent_keyboard_interrupt():
                                result = policy_client.infer(request_data)
                            print(result["actions"].shape)
                            action_chunk = result["actions"]
                            if args.state_as_action:
                                action_chunk = np.concatenate([action_chunk[..., 8:], action_chunk[..., 7:8]], axis=-1)
                            else:
                                action_chunk = action_chunk[..., :8]
                            pred_action_chunk, scores = _select_best_action_chunk(
                                action_chunk,
                                env,
                                args.num_samples,
                                args.spline_curvature_length,
                                args.spline_curvature_num_points,
                            )
                            if scores is not None:
                                spline_curvature_list.append(scores.max().item())

                    if pred_action_chunk is not None:
                        # CONVERT AND EXECUTE CHUNK
                        chunk_to_execute = pred_action_chunk.copy()
                        logger.info(f"current state: {env.get_state()[0]['joint_positions']}")
                        logger.info(f"chunk_to_execute: {chunk_to_execute[-1]}")
                        if chunk_to_execute.shape[1] == 15:
                            chunk_to_execute = np.concatenate([chunk_to_execute[:, 8:], chunk_to_execute[:, 7:8]], axis=1)
                            logger.info(f"final chunk_to_execute: {chunk_to_execute[-1]}")
                        pred_action_chunk = None

                        # Binarize gripper action
                        gripper_actions = chunk_to_execute[:, -1]
                        logger.info(f"gripper_actions: {gripper_actions}")
                        binarized_gripper = np.where(gripper_actions > 0.5, 1.0, 0.0)
                        chunk_to_execute[:, -1] = binarized_gripper

                        # clip all dimensions of action to [-1, 1]
                        if args.action_space == "joint_velocity":
                            chunk_to_execute = np.clip(chunk_to_execute, -1, 1)

                        last_chunk_index = env.get_replan_chunk_index()
                        env.step_joint(chunk_to_execute)
                        
                        current_plan_for_consistency = chunk_to_execute.copy()
                        remaining_waypoints_at_execution_start = env.get_num_remaining_waypoints()
                        remaining_waypoints = remaining_waypoints_at_execution_start

                        # Wait for the new chunk to be acknowledged and processed by the robot
                        start_wait_time = time.time()
                        replanned = True
                        while env.get_replan_chunk_index() <= last_chunk_index:
                            time.sleep(0.01)
                            if time.time() - start_wait_time > 0.05:  # 0.05 second timeout
                                logging.warning(f"Timeout waiting for robot to process new trajectory chunk after index {last_chunk_index}.")
                                replanned = False
                                break
                        logger.info(f"Took {time.time() - start_wait_time} seconds for robot to process new trajectory chunk.")

                    # 3. TRIGGER NEW ASYNC INFERENCE IF NEEDED
                    if inference_future is None:
                        if True:
                            time_threshold = max(inference_delays)
                            remaining_time = env.robot.get_remaining_execution_time()
                            executed_waypoints = remaining_waypoints_at_execution_start - env.get_num_remaining_waypoints()
                            expected_execution_during_inference = env.robot.get_expected_execution_during_inference(max(inference_delays))

                            while remaining_time > time_threshold and executed_waypoints + steps_during_inference + expected_execution_during_inference < args.open_loop_horizon:
                                time.sleep(0.01)  # Wait until it's time to replan
                                remaining_time = env.robot.get_remaining_execution_time()
                                executed_waypoints = remaining_waypoints_at_execution_start - env.get_num_remaining_waypoints()
                                expected_execution_during_inference = env.robot.get_expected_execution_during_inference(max(inference_delays))
                                
                            logger.info(f"Remaining time ({remaining_time:.4f}s) <= threshold ({time_threshold:.4f}s) or open loop horizon ({args.open_loop_horizon} waypoints) reached, starting async inference.")
                        # else:
                        #     remaining_waypoints = env.robot.get_num_remaining_waypoints()
                        #     while 15 - remaining_waypoints + np.mean(discard_steps_history) < args.open_loop_horizon:
                        #         logger.info(f"Remaining waypoints: {remaining_waypoints}, expected execution length: {15 - remaining_waypoints + np.mean(discard_steps_history)} waiting...")
                        #         time.sleep(0.01)
                        #         remaining_waypoints = env.robot.get_num_remaining_waypoints()
                        
                        # remaining_waypoints = env.robot.get_num_remaining_waypoints()
                        # while remaining_waypoints > 0:
                        #     logger.info(f"Remaining waypoints: {remaining_waypoints}, waiting...")
                        #     time.sleep(0.01)
                        #     remaining_waypoints = env.robot.get_num_remaining_waypoints()
                        
                        
                        with shared_data["lock"]:
                            curr_obs = copy.deepcopy(shared_data["latest_observation"])
                        if curr_obs is None:
                            time.sleep(0.01)
                            continue

                        if current_plan_for_consistency is not None:
                            executed_waypoints = remaining_waypoints_at_execution_start - env.get_num_remaining_waypoints()
                            if executed_waypoints < len(current_plan_for_consistency) and executed_waypoints >= 0:
                                prev_plan_for_consistency_check = current_plan_for_consistency[executed_waypoints:]
                            else:
                                prev_plan_for_consistency_check = None
                        else:
                            prev_plan_for_consistency_check = None

                        request_data = {
                            "obs": {
                                "observation/exterior_image_1_left": image_tools.resize_with_pad(
                                    curr_obs[f"{args.external_camera}_image"], 224, 224
                                ),
                                "observation/wrist_image_left": image_tools.resize_with_pad(
                                    curr_obs["wrist_image"], 224, 224
                                ),
                                "observation/joint_position": curr_obs["joint_position"],
                                "observation/gripper_position": curr_obs["gripper_position"],
                                "prompt": instruction,
                            },
                            "sample_kwargs": {
                                "num_samples": args.num_samples,
                            },
                        }
                        inference_start_waypoint = env.get_total_completed_waypoints()
                        request_data["step_index"] = inference_start_waypoint
                        inference_future = executor.submit(policy_client.infer, request_data)
                        inference_start_time = time.time()

                    # Sleep to match DROID data collection frequency
                    elapsed_time = time.time() - start_time

            except KeyboardInterrupt:
                logger.info("\nCaught KeyboardInterrupt, stopping rollout.")
            finally:
                rollout_duration = time.time() - rollout_start_time
                
                # Stop the observation thread
                shared_data["stop"] = True
                obs_thread.join()
                if inference_future is not None:
                    inference_future.cancel()
                
                # Shutdown Toppra and robot services
                toppra_server.stop()
                env.robot.close()
        finally:
            termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)

        with shared_data["lock"]:
            frames_written = shared_data["frames_written"]

        if frames_written == 0:
            if input("No video was recorded. Do one more eval? (enter y or n) ").lower() != "y":
                break
            env.reset()
            continue

        scores_filename = os.path.join(save_dir, "scores.json")
        with open(scores_filename, "w") as f:
            json.dump({"score_timestamps": score_timestamps, "total_score": len(score_timestamps)}, f, indent=4)

        _, _, joint_errors = env.robot.get_errors()
        initial_controllable_set_sizes = env.robot.get_initial_controllable_set_sizes()
        controllable_set_sizes = env.robot.get_controllable_set_sizes()
        metrics = {
            "joint_error": np.mean(joint_errors) if joint_errors else 0.0,
            "spline_curvature": np.mean(spline_curvature_list) if spline_curvature_list else 0.0,
            "initial_controllable_set_size": np.mean(initial_controllable_set_sizes) if initial_controllable_set_sizes else 0.0,
            "controllable_set_size": np.mean(controllable_set_sizes) if controllable_set_sizes else 0.0,
            "normalized_action_consistency": np.mean(normalized_action_consistency_list) if normalized_action_consistency_list else 0.0,
            "joint_error_raw": joint_errors,
            "spline_curvature_raw": spline_curvature_list,
            "initial_controllable_set_size_raw": initial_controllable_set_sizes,
            "controllable_set_size_raw": controllable_set_sizes,
            "normalized_action_consistency_raw": normalized_action_consistency_list,
        }
        metrics_filename = os.path.join(save_dir, "metrics.json")
        with open(metrics_filename, "w") as f:
            json.dump(metrics, f, indent=4)

        success = None
        while success is None:
            raw_success = input(
                "Did the rollout succeed? (enter y for 100%, n for 0%), or a numeric value 0-100 based on the evaluation spec: "
            )
            if raw_success.lower() == "y":
                success = 1.0
            elif raw_success.lower() == "n":
                success = 0.0
            else:
                try:
                    numeric_success = float(raw_success)
                    if 0 <= numeric_success <= 100:
                        success = numeric_success / 100.0
                    else:
                        logger.info(f"Success must be a number in [0, 100] but got: {numeric_success}")
                except ValueError:
                    logger.info("Invalid input. Please enter 'y', 'n', or a number between 0 and 100.")

        rollout_result = {
            "success": success,
            "steps": t_step,
            "duration_seconds": rollout_duration,
            "video_filename": video_filename,
            "instruction": instruction,
            "name": args.name,
            "timestamp": timestamp,
        }
        results.append(rollout_result)

        df_single = pd.DataFrame([rollout_result])
        csv_filename = os.path.join(save_dir, "result.csv")
        df_single.to_csv(csv_filename, index=False)

        if input("Do one more eval? (enter y or n) ").lower() != "y":
            break
        env.reset()

    executor.shutdown()
    if results:
        df = pd.DataFrame(results)
        summary_dir = os.path.join("results/eval_summary", args.name)
        os.makedirs(summary_dir, exist_ok=True)
        timestamp = datetime.datetime.now().strftime("%I:%M%p_%B_%d_%Y")
        csv_filename = os.path.join(summary_dir, f"eval_summary_{timestamp}.csv")
        df.to_csv(csv_filename, index=False)
        logger.info(f"Results summary saved to {csv_filename}")


def _bgra_to_rgb(image):
    """Return a 3-channel, channel-reversed version of ``image`` (None passes through).

    Keeps the first three entries of the last axis (dropping an alpha channel
    if present) and reverses their order. Assumes frames arrive channel-last
    in BGR(A) order -- TODO confirm against the camera driver. For NumPy
    inputs the result is a view of the input, not a copy.
    """
    if image is None:
        return None
    return image[..., :3][..., ::-1]


def _extract_observation(args: "Args", obs_dict, *, save_to_disk=False, save_dir=None):
    """Pull the policy-relevant camera frames and proprioceptive state out of ``obs_dict``.

    Args:
        args: Run configuration; reads ``left_camera_id``, ``right_camera_id``,
            ``wrist_camera_id`` and ``external_camera``.
        obs_dict: Raw observation with an ``"image"`` mapping (key -> frame) and a
            ``"robot_state"`` mapping providing ``ee_pos``, ``ee_quat``,
            ``joint_positions`` and ``gripper_width``.
        save_to_disk: When True, also write a side-by-side preview image
            (wrist view next to the selected external view) for live viewing.
        save_dir: Directory that receives ``robot_camera_views.png``; required
            when ``save_to_disk`` is True.

    Returns:
        Dict with keys ``left_image``, ``right_image``, ``wrist_image`` (each an
        RGB frame or None if no matching key was found), ``cartesian_position``
        (``ee_pos`` concatenated with ``ee_quat``), ``joint_position`` and
        ``gripper_position``.

    Raises:
        TypeError: If ``save_to_disk`` is True but ``save_dir`` is None.
        KeyError: If ``obs_dict`` lacks the ``"image"`` / ``"robot_state"``
            entries or the robot-state fields listed above.
    """
    # Fail early with a clear message instead of deep inside os.path.join.
    if save_to_disk and save_dir is None:
        raise TypeError("save_dir must be provided when save_to_disk=True")

    image_observations = obs_dict["image"]
    left_image, right_image, wrist_image = None, None, None
    for key in image_observations:
        # Note the "left" below refers to the left camera in the stereo pair.
        # The model is only trained on left stereo cams, so we only feed those.
        if "left" not in key:
            continue
        # NOTE: the elif chain gives the left camera precedence. If the left and
        # right camera IDs are configured to the same serial, only left_image
        # is populated and right_image stays None.
        if args.left_camera_id in key:
            left_image = image_observations[key]
        elif args.right_camera_id in key:
            right_image = image_observations[key]
        elif args.wrist_camera_id in key:
            wrist_image = image_observations[key]

    # Drop the alpha dimension and convert to RGB
    left_image = _bgra_to_rgb(left_image)
    right_image = _bgra_to_rgb(right_image)
    wrist_image = _bgra_to_rgb(wrist_image)

    # In addition to image observations, also capture the proprioceptive state
    robot_state = obs_dict["robot_state"]
    cartesian_position = np.concatenate([robot_state["ee_pos"], robot_state["ee_quat"]])
    joint_position = np.array(robot_state["joint_positions"])
    gripper_position = np.array(robot_state["gripper_width"])

    # Save the images to disk so that they can be viewed live while the robot is running
    # Create one combined image to make live viewing easy
    if save_to_disk:
        # Any value other than "left" selects the right camera.
        external_image = left_image if args.external_camera == "left" else right_image
        # The preview is best-effort: a missing camera frame must not abort
        # observation extraction, so only the views that exist are stacked.
        views = [view for view in (wrist_image, external_image) if view is not None]
        if views:
            combined_image = Image.fromarray(np.concatenate(views, axis=1))
            combined_image.save(os.path.join(save_dir, "robot_camera_views.png"))
        else:
            logging.getLogger('inference_toppra').warning(
                "No camera images available; skipping live preview save."
            )

    return {
        "left_image": left_image,
        "right_image": right_image,
        "wrist_image": wrist_image,
        "cartesian_position": cartesian_position,
        "joint_position": joint_position,
        "gripper_position": gripper_position,
    }


if __name__ == "__main__":
    # Script entry point: build the Args dataclass from command-line flags
    # with tyro, then hand it to main().
    args: Args = tyro.cli(Args)
    main(args)
