from typing import Any, Dict, Optional

from gr00t.data.dataset import ModalityConfig
from gr00t.eval.service import BaseInferenceClient, BaseInferenceServer
from gr00t.model.policy import BasePolicy
class RobotInferenceServer(BaseInferenceServer):
    """
    Inference server that exposes a robot policy to remote clients.

    In addition to whatever endpoints ``BaseInferenceServer`` sets up on its
    own (not visible from this file), this class registers two:

    * ``get_action``          -> ``model.get_action``
    * ``get_modality_config`` -> ``model.get_modality_config``
      (registered with ``requires_input=False``)
    """

    def __init__(
        self,
        model,
        host: str = "*",
        port: int = 5555,
        api_token: Optional[str] = None,
    ) -> None:
        """
        Create the server and register the policy endpoints.

        Args:
            model: Object providing ``get_action`` and ``get_modality_config``
                callables (``start_server`` passes a ``BasePolicy``).
            host: Forwarded unchanged to ``BaseInferenceServer``.
                NOTE(review): "*" presumably means "bind on all interfaces";
                confirm against the base class.
            port: Forwarded unchanged to ``BaseInferenceServer``.
            api_token: Optional token forwarded unchanged to
                ``BaseInferenceServer``; how it is enforced is decided there.
        """
        super().__init__(host, port, api_token)
        self.register_endpoint("get_action", model.get_action)
        self.register_endpoint(
            "get_modality_config", model.get_modality_config, requires_input=False
        )

    @staticmethod
    def start_server(
        policy: BasePolicy,
        port: int,
        api_token: Optional[str] = None,
        host: str = "*",
    ) -> None:
        """
        Build a ``RobotInferenceServer`` around ``policy`` and run it.

        Args:
            policy: Policy whose methods are exposed as endpoints.
            port: Port handed to the server constructor.
            api_token: Optional token handed to the server constructor.
            host: Host handed to the server constructor. Defaults to "*",
                the same default the constructor uses, so existing callers
                that omit it see no change.

        NOTE(review): ``server.run()`` presumably blocks until the server is
        stopped; confirm against ``BaseInferenceServer.run``.
        """
        server = RobotInferenceServer(policy, host=host, port=port, api_token=api_token)
        server.run()
class RobotInferenceClient(BaseInferenceClient, BasePolicy):
    """
    Client-side policy that forwards its calls to a ``RobotInferenceServer``.

    It implements the two ``BasePolicy``-style methods defined below by calling
    the identically named server endpoints through ``call_endpoint``.
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 5555,
        api_token: Optional[str] = None,
    ) -> None:
        """
        Args:
            host: Server host, passed by keyword to the parent initializer.
            port: Server port, passed by keyword to the parent initializer.
            api_token: Optional token, passed by keyword to the parent
                initializer; how it is used is decided there.
        """
        super().__init__(host=host, port=port, api_token=api_token)

    def get_action(self, observations: Dict[str, Any]) -> Dict[str, Any]:
        """
        Send ``observations`` to the ``get_action`` endpoint and return its reply.

        Args:
            observations: Observation dictionary passed as the endpoint payload.

        Returns:
            Whatever ``call_endpoint`` returns for ``get_action``.
        """
        return self.call_endpoint("get_action", observations)

    def get_modality_config(self) -> Dict[str, ModalityConfig]:
        """
        Query the ``get_modality_config`` endpoint, which takes no payload
        (``requires_input=False``), and return its reply.
        """
        return self.call_endpoint("get_modality_config", requires_input=False)
