"""
GR00T Inference Service
This script provides both ZMQ and HTTP server/client implementations for deploying GR00T models.
The HTTP server exposes a REST API for easy integration with web applications and other services.
1. ZMQ server (the default mode):
Run server: python scripts/inference_service.py --server
Run client: python scripts/inference_service.py --client
2. HTTP server:
Dependencies for `http_server` mode:
    => Server (runs GR00T model on GPU): `pip install uvicorn fastapi json-numpy`
    => Client: `pip install requests json-numpy`
HTTP Server Usage:
    python scripts/inference_service.py --server --http-server --port 8000
HTTP Client Usage (assuming a server running on 0.0.0.0:8000):
    python scripts/inference_service.py --client --http-server --host 0.0.0.0 --port 8000
You can use bore to forward the port to your client (`159.223.171.199` is the address of bore.pub):
    bore local 8000 --to 159.223.171.199
"""
import time
from dataclasses import dataclass
from typing import Literal, Optional

import numpy as np
import tyro

from gr00t.data.embodiment_tags import EMBODIMENT_TAG_MAPPING
from gr00t.eval.robot import RobotInferenceClient, RobotInferenceServer
from gr00t.experiment.data_config import load_data_config
from gr00t.model.policy import Gr00tPolicy
@dataclass
class ArgsConfig:
    """Command line arguments for the inference service."""

    model_path: str = "nvidia/GR00T-N1.5-3B"
    """Path to the model checkpoint directory."""

    # The allowed values are the keys of EMBODIMENT_TAG_MAPPING, expanded into
    # a Literal so that only known tags are accepted.
    embodiment_tag: Literal[tuple(EMBODIMENT_TAG_MAPPING.keys())] = "gr1"
    """The embodiment tag for the model."""

    data_config: str = "fourier_gr1_arms_waist"
    """
    The name of the data config to use, e.g. so100, fourier_gr1_arms_only, unitree_g1, etc.
    Or a path to a custom data config file. e.g. "module:ClassName" format.
    See gr00t/experiment/data_config.py for more details.
    """

    port: int = 5555
    """The port number for the server."""

    host: str = "localhost"
    """The host address for the server."""

    server: bool = False
    """Whether to run the server."""

    client: bool = False
    """Whether to run the client."""

    denoising_steps: int = 4
    """The number of denoising steps to use."""

    # Optional[...] because the default is None (authentication disabled);
    # a bare `str` annotation contradicted that default.
    api_token: Optional[str] = None
    """API token for authentication. If not provided, authentication is disabled."""

    http_server: bool = False
    """Whether to run it as HTTP server. Default is ZMQ server."""
def _example_zmq_client_call(obs: dict, host: str, port: int, api_token: str):
    """
    Send one observation to a ZMQ inference server and return its action.

    A ``RobotInferenceClient`` is created for ``host``/``port`` with the given
    ``api_token``. The keys of the server's modality config are printed first,
    then ``obs`` is submitted and the wall-clock duration of the request is
    reported on stdout.

    Args:
        obs: Observation dictionary forwarded unchanged to ``get_action``.
        host: Address of the inference server.
        port: Port of the inference server.
        api_token: Token handed to the client constructor.

    Returns:
        Whatever ``RobotInferenceClient.get_action`` returns for ``obs``.
    """
    client = RobotInferenceClient(host=host, port=port, api_token=api_token)

    # Show which modalities the server is configured for.
    print("Available modality config available:")
    print(client.get_modality_config().keys())

    # Time only the action request itself.
    started_at = time.time()
    result = client.get_action(obs)
    elapsed = time.time() - started_at
    print(f"Total time taken to get action from server: {elapsed} seconds")

    return result
def _example_http_client_call(obs: dict, host: str, port: int, api_token: str):
    """
    Example HTTP client call to the server.

    POSTs ``{"observation": obs}`` as JSON to ``http://{host}:{port}/act`` and
    prints how long the round trip took.

    Args:
        obs: Observation dictionary sent under the ``"observation"`` key.
        host: Host of the HTTP inference server.
        port: Port of the HTTP inference server.
        api_token: Optional API token. When set, it is sent as a Bearer token
            in the ``Authorization`` header; when ``None`` or empty, no
            authentication header is sent.

    Returns:
        The decoded JSON response body on HTTP 200. For any other status code
        the error is printed and an empty dict is returned.
    """
    # Imported lazily: these packages are only required for the HTTP mode.
    import json_numpy

    # Patch before importing requests, as the original ordering did.
    # Presumably this lets the numpy arrays in `obs` be JSON-encoded and
    # decoded — see the json-numpy documentation.
    json_numpy.patch()
    import requests

    # The token used to be accepted but never transmitted. It is now sent as a
    # standard Bearer credential.
    # NOTE(review): assumes the server reads the `Authorization` header —
    # confirm against HTTPInferenceServer.
    headers = {"Authorization": f"Bearer {api_token}"} if api_token else None

    print("Testing HTTP server...")
    time_start = time.time()
    response = requests.post(
        f"http://{host}:{port}/act",
        json={"observation": obs},
        headers=headers,
    )
    print(f"Total time taken to get action from HTTP server: {time.time() - time_start} seconds")

    if response.status_code != 200:
        # Best-effort example client: report the failure and return no action.
        print(f"Error: {response.status_code} - {response.text}")
        return {}

    return response.json()
def main(args: ArgsConfig):
    """
    Run the inference service in server or client mode.

    Server mode (``args.server``) loads the data config, builds a
    ``Gr00tPolicy`` and serves it over HTTP when ``args.http_server`` is set,
    otherwise over ZMQ. Client mode (``args.client``) sends one randomly
    generated observation to the server and prints the shape of every entry
    in the returned action. If both flags are set, server mode wins.

    Args:
        args: Parsed command line configuration.

    Raises:
        ValueError: If neither ``args.server`` nor ``args.client`` is set.
    """
    # Guard: a mode must be chosen.
    if not (args.server or args.client):
        raise ValueError("Please specify either --server or --client")

    if args.server:
        cfg = load_data_config(args.data_config)
        policy = Gr00tPolicy(
            model_path=args.model_path,
            modality_config=cfg.modality_config(),
            modality_transform=cfg.transform(),
            embodiment_tag=args.embodiment_tag,
            denoising_steps=args.denoising_steps,
        )

        if args.http_server:
            # Imported only when needed so ZMQ mode has no HTTP dependencies.
            from gr00t.eval.http_server import HTTPInferenceServer

            server = HTTPInferenceServer(
                policy, port=args.port, host=args.host, api_token=args.api_token
            )
        else:
            server = RobotInferenceServer(policy, port=args.port, api_token=args.api_token)

        server.run()
        return

    # Client mode: build a random observation for the example request.
    obs = {
        "video.ego_view": np.random.randint(0, 256, (1, 256, 256, 3), dtype=np.uint8),
        "state.left_arm": np.random.rand(1, 7),
        "state.right_arm": np.random.rand(1, 7),
        "state.left_hand": np.random.rand(1, 6),
        "state.right_hand": np.random.rand(1, 6),
        "state.waist": np.random.rand(1, 3),
        "annotation.human.action.task_description": ["do your thing!"],
    }

    # Pick the transport-specific example call, then invoke it.
    if args.http_server:
        send_request = _example_http_client_call
    else:
        send_request = _example_zmq_client_call
    action = send_request(obs, args.host, args.port, args.api_token)

    for name, array in action.items():
        print(f"Action: {name}: {array.shape}")
if __name__ == "__main__":
    # Script entry point: build an ArgsConfig from the command line with tyro
    # and hand it to main().
    config = tyro.cli(ArgsConfig)
    main(config)
