"""
LIBERO WebSocket Client - 简化版本

支持batch_size * num_trajectory的并行推理。
"""

import logging
import time
import numpy as np
from typing import Dict, Tuple
import websockets.sync.client
import websockets.exceptions

try:
    from . import msgpack_numpy
except ImportError:
    import msgpack_numpy


class LiberoWebsocketClient:
    """Websocket client for a remote LIBERO environment server.

    The client connects on construction (blocking until the server accepts
    the connection) and receives the server's metadata message. After that it
    exchanges msgpack-numpy encoded request/response dicts with the server.
    Each request carries a ``"method"`` key: ``"reset"``, ``"step"``,
    ``"chunk_step"`` or ``"close"``.

    It can be used as a context manager. Leaving the ``with`` block calls
    :meth:`close`.
    """

    def __init__(self, host: str = "0.0.0.0", port: int = 8003, timeout: int = 150):
        """Connect to the LIBERO server and fetch its metadata.

        Blocks until the server accepts a connection (see ``_wait_for_server``).

        Args:
            host: Server host name or address.
                NOTE(review): "0.0.0.0" is used as a *connect* target here; that
                reaches localhost on Linux but is not portable — consider
                "127.0.0.1".
            port: Server port.
            timeout: Seconds, passed to websockets as both ``open_timeout`` and
                ``close_timeout``.
        """
        self.host = host
        self.port = port
        self.timeout = timeout
        self._packer = msgpack_numpy.Packer()

        self._ws, self._server_metadata = self._wait_for_server()

    def _wait_for_server(self):
        """Open a websocket connection, retrying while the server refuses it.

        Retries every 5 seconds, with no upper bound, as long as the connection
        is refused. Any other error propagates to the caller.

        Returns:
            A ``(connection, metadata)`` tuple. ``metadata`` is the unpacked
            first message the server sends after the handshake.
        """
        uri = f"ws://{self.host}:{self.port}"

        logging.info(f"Waiting for LIBERO server at {uri}...")
        while True:
            try:
                conn = websockets.sync.client.connect(
                    uri,
                    compression=None,
                    max_size=None,  # no limit on incoming message size
                    open_timeout=self.timeout,
                    close_timeout=self.timeout,
                    # Keepalive pings are disabled entirely to avoid
                    # keepalive-timeout failures.
                    ping_timeout=None,
                    ping_interval=None,
                )
            except ConnectionRefusedError:
                logging.info("Still waiting for LIBERO server...")
                time.sleep(5)
                continue

            try:
                metadata = msgpack_numpy.unpackb(conn.recv())
            except Exception:
                # Don't leak the socket if the metadata handshake fails.
                conn.close()
                raise
            logging.info("Successfully connected to LIBERO server")
            return conn, metadata

    def _reconnect(self):
        """Drop the current connection (best effort) and connect again.

        Blocks until the server accepts a new connection. It then replaces both
        the websocket and the cached server metadata.
        """
        logging.warning("Connection lost. Attempting to reconnect...")
        stale_ws = getattr(self, "_ws", None)
        if stale_ws is not None:
            try:
                stale_ws.close()
            except Exception as e:
                # The old socket is already broken; failing to close it
                # cleanly must not prevent reconnecting.
                logging.debug(f"Ignoring error while closing stale connection: {e}")

        self._ws, self._server_metadata = self._wait_for_server()
        logging.info("Successfully reconnected to LIBERO server")

    def _send_recv(self, data: Dict, max_retries: int = 3) -> Dict:
        """Send one request and return the server's unpacked response.

        On a connection-level failure (websocket closed, or ``OSError``) the
        client reconnects and re-sends the same request. It makes at most
        ``max_retries`` attempts in total.

        NOTE(review): a resent ``step``/``chunk_step`` request goes over a
        fresh connection. Whether the server keeps the same environment state
        across a reconnect is not visible here. Confirm before relying on the
        retry for non-idempotent requests.

        Args:
            data: Request dict, packed with the msgpack-numpy packer.
            max_retries: Total number of send/receive attempts (must be >= 1).

        Returns:
            The unpacked response dict.

        Raises:
            ValueError: If ``max_retries`` is less than 1.
            RuntimeError: If the server replies with a text frame (treated as
                a server-side error message), or if every attempt fails with a
                connection error.
        """
        if max_retries < 1:
            raise ValueError(f"max_retries must be >= 1, got {max_retries}")

        # The payload is identical on every attempt, so pack it once.
        packed_data = self._packer.pack(data)
        for attempt in range(1, max_retries + 1):
            try:
                self._ws.send(packed_data)
                response_msg = self._ws.recv()
                if isinstance(response_msg, str):
                    # Text frames are interpreted as server error reports.
                    raise RuntimeError(f"Error in LIBERO server:\n{response_msg}")
                return msgpack_numpy.unpackb(response_msg)
            except (websockets.exceptions.ConnectionClosed, OSError) as e:
                # ConnectionClosed covers ConnectionClosedError/OK;
                # OSError covers BrokenPipeError, ConnectionResetError, ...
                logging.warning(f"Connection error (attempt {attempt}/{max_retries}): {e}")
                if attempt == max_retries:
                    raise RuntimeError(
                        f"Failed to send/receive after {max_retries} attempts: {e}"
                    ) from e
                self._reconnect()
            except Exception as e:
                logging.error(f"Unexpected error in _send_recv: {e}")
                raise

    def reset(self, sim_state: np.ndarray, sim_state_len: np.ndarray, task_id: np.ndarray) -> Dict:
        """Reset the server's environments.

        Args:
            sim_state: Simulator state array to reset to.
            sim_state_len: Length information for ``sim_state``.
                NOTE(review): exact meaning/shape is defined by the server — confirm.
            task_id: Task ID(s) to reset to.

        Returns:
            The ``"observation"`` entry of the server's response.
        """
        data = {
            "method": "reset",
            "sim_state": sim_state,
            "sim_state_len": sim_state_len,
            "task_id": task_id,
        }

        response = self._send_recv(data)
        return response["observation"]

    def step(self, actions: np.ndarray) -> Tuple[Dict, np.ndarray, np.ndarray, np.ndarray]:
        """Execute a single action step on the server.

        Args:
            actions: Action array, expected as ``[num_envs, action_dim]``.

        Returns:
            An ``(observation, rewards, terminations, truncations)`` tuple,
            taken from the corresponding keys of the server's response.
        """
        data = {
            "method": "step",
            "actions": actions,
        }

        response = self._send_recv(data)

        return (
            response["observation"],
            response["rewards"],
            response["terminations"],
            response["truncations"],
        )

    def chunk_step(self, chunk_actions: np.ndarray) -> Dict:
        """Execute a chunk of consecutive action steps on the server.

        Args:
            chunk_actions: Action array, expected as
                ``[num_envs, chunk_size, action_dim]``.

        Returns:
            A dict with the server's ``"observation"``, ``"rewards"``,
            ``"terminations"`` and ``"truncations"`` entries.
        """
        data = {
            "method": "chunk_step",
            "chunk_actions": chunk_actions,
        }

        response = self._send_recv(data)

        return {
            "observation": response["observation"],
            "rewards": response["rewards"],
            "terminations": response["terminations"],
            "truncations": response["truncations"],
        }

    def get_server_metadata(self) -> Dict:
        """Return the metadata received on the most recent (re)connection."""
        return self._server_metadata

    def close(self):
        """Send a ``"close"`` request to the server, then close the socket.

        Calling this more than once is safe. The connection reference is
        cleared on the first call, and later calls do nothing. Errors while
        notifying the server or closing the socket are logged, not raised.
        This lets ``__exit__`` call it without masking an in-flight exception.
        """
        ws = self._ws
        if ws is None:
            return
        # Clear the reference first so the client never keeps pointing at a
        # half-closed socket, whatever happens below.
        self._ws = None

        try:
            ws.send(self._packer.pack({"method": "close"}))
            ws.recv()  # wait for the server's reply; its content is unused
        except Exception as e:
            logging.warning(f"Error during close: {e}")

        try:
            ws.close()
        except Exception as e:
            logging.warning(f"Error during close: {e}")
        else:
            logging.info("Closed connection to LIBERO server")

    def __enter__(self):
        """Enter the context manager; returns the client itself."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit the context manager: close the connection, never suppress exceptions."""
        self.close()
        

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Smoke test against a running server. The context manager guarantees the
    # connection is closed even if a request fails.
    with LiberoWebsocketClient() as client:
        # reset() requires all three arguments; the values here are
        # placeholders. NOTE(review): confirm the expected sim_state /
        # sim_state_len / task_id shapes and values with the server.
        obs = client.reset(
            sim_state=np.array([0]),
            sim_state_len=np.array([1]),
            task_id=np.array([0]),
        )
        logging.info(f"Initial observation keys: {obs.keys()}")

        num_envs = client.get_server_metadata()["num_envs"]
        # NOTE(review): 7 is assumed to be the action dimension — confirm.
        random_actions = np.random.randn(num_envs, 7) * 0.1
        obs, rewards, terminations, truncations = client.step(random_actions)
        logging.info(f"Step completed, rewards: {rewards}")
