"""
AdaDS Agent 客户端查询模板
用于从另一个 Python 环境查询 AdaDS Agent 服务器（使用 WebSocket）

支持的接口：
1. predict_k: 选择最佳降采样率
2. state_deviation: 计算 ds=2 的 state deviation

注意：
- 客户端应发送未标准化的状态和动作序列
- 服务器端会自动处理标准化（如果配置了 obs_mean/obs_std 和 action_mean/action_std）
- 使用 WebSocket 可以避免与 HTTP proxy 的冲突
- 使用持久连接类 AdaDSClient 可以避免频繁握手，提高推理速度
"""

import numpy as np
from typing import Optional, Dict, Any
import logging
import time

import websockets.sync.client
import functools
import msgpack


def pack_array(obj):
    """将 NumPy 数组打包为 msgpack 格式"""
    if (isinstance(obj, (np.ndarray, np.generic))) and obj.dtype.kind in ("V", "O", "c"):
        raise ValueError(f"Unsupported dtype: {obj.dtype}")

    if isinstance(obj, np.ndarray):
        return {
            b"__ndarray__": True,
            b"data": obj.tobytes(),
            b"dtype": obj.dtype.str,
            b"shape": obj.shape,
        }

    if isinstance(obj, np.generic):
        return {
            b"__npgeneric__": True,
            b"data": obj.item(),
            b"dtype": obj.dtype.str,
        }

    return obj


def unpack_array(obj):
    """从 msgpack 格式解包 NumPy 数组"""
    if b"__ndarray__" in obj:
        return np.ndarray(buffer=obj[b"data"], dtype=np.dtype(obj[b"dtype"]), shape=obj[b"shape"])

    if b"__npgeneric__" in obj:
        return np.dtype(obj[b"dtype"]).type(obj[b"data"])

    return obj


# 创建带有 NumPy 支持的 Packer 和 Unpacker
Packer = functools.partial(msgpack.Packer, default=pack_array)
packb = functools.partial(msgpack.packb, default=pack_array)

Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array)
unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array)


class AdaDSClient:
    """AdaDS Agent WebSocket 客户端（持久连接）
    
    建立一次连接，可以多次查询，避免频繁握手。
    
    使用示例：
        # 创建客户端（自动连接）
        client = AdaDSClient(host="127.0.0.1", port=8888)
        
        # 多次查询（复用连接）
        for i in range(100):
            predicted_k = client.query_predict_k(s_env, A)
            deviation = client.query_state_deviation(s_env, A)
        
        # 关闭连接
        client.close()
        
        # 或使用 with 语句自动管理连接
        with AdaDSClient(host="127.0.0.1", port=8888) as client:
            for i in range(100):
                predicted_k = client.query_predict_k(s_env, A)
    """
    
    def __init__(self, host: str = "127.0.0.1", port: int = 8888, timeout: int = 10):
        """初始化客户端并建立连接
        
        Args:
            host: 服务器地址
            port: 端口号
            timeout: 连接超时时间（秒）
        """
        self.host = host
        self.port = port
        self.timeout = timeout
        
        # 构建 URI
        if host.startswith("ws://") or host.startswith("wss://"):
            self._uri = host
        else:
            self._uri = f"ws://{host}:{port}"
        
        # 创建 Packer 实例（复用）
        self._packer = Packer()
        
        # 连接到服务器
        self._conn, self._server_metadata = self._wait_for_server()
        
        logging.info(f"Connected to {self._uri}, metadata: {self._server_metadata}")
    
    def _wait_for_server(self):
        """等待服务器连接（带重试）"""
        logging.info(f"Connecting to server at {self._uri}...")
        while True:
            try:
                conn = websockets.sync.client.connect(
                    self._uri,
                    open_timeout=self.timeout,
                    compression=None,
                    max_size=None
                )
                # 接收服务器元数据
                metadata = unpackb(conn.recv())
                return conn, metadata
            except ConnectionRefusedError:
                logging.info("Server not ready, waiting...")
                time.sleep(5)
            except Exception as e:
                logging.error(f"Connection failed: {e}")
                raise ConnectionError(f"无法连接到服务器 {self._uri}: {e}")
    
    def get_server_metadata(self) -> Dict:
        """获取服务器元数据"""
        return self._server_metadata
    
    def query_predict_k(
        self, 
        s_env: np.ndarray, 
        A: np.ndarray, 
        rgb_obs: Optional[np.ndarray] = None
    ) -> int:
        """查询预测的降采样率 k（复用连接）
        
        Args:
            s_env: 状态数组 [obs_dim] 或 [history_len, obs_dim]，未标准化的原始状态
            A: 动作序列 [T, action_dim]，未归一化的原始动作序列
            rgb_obs: RGB 观测 [num_cameras, H, W, C] 或 [history_len, num_cameras, H, W, C]（可选）
        
        Returns:
            predicted_k: 预测的降采样率
        
        Raises:
            RuntimeError: 如果服务器返回错误
            ValueError: 如果响应格式不正确
        """
        # 发送预测请求
        request = {
            "type": "predict_k",
            "s_env": s_env,
            "A": A
        }
        if rgb_obs is not None:
            request["rgb_obs"] = rgb_obs
        
        self._conn.send(self._packer.pack(request))
        
        # 接收响应
        response = self._conn.recv()
        if isinstance(response, str):
            # 字符串响应表示错误
            raise RuntimeError(f"Server error:\n{response}")
        
        result = unpackb(response)
        
        # 检查响应状态
        if result.get("status") == "error":
            raise ValueError(f"服务器返回错误: {result.get('message', 'Unknown error')}")
        
        if result.get("status") != "success":
            raise ValueError(f"未知的响应状态: {result.get('status')}")
        
        if "predicted_k" not in result:
            raise ValueError("响应中缺少 'predicted_k' 字段")
        
        return int(result["predicted_k"])
    
    def query_state_deviation(
        self,
        s_env: np.ndarray,
        A: np.ndarray,
        rgb_obs: Optional[np.ndarray] = None,
        min_ds: int = 1,
        minimum_decay_steps: int = 2
    ) -> Dict[str, Any]:
        """查询 ds=2 的 state deviation（复用连接）
        
        Args:
            s_env: 状态数组 [obs_dim] 或 [history_len, obs_dim]，未标准化的原始状态
            A: 动作序列 [T, action_dim]，未归一化的原始动作序列
            rgb_obs: RGB 观测 [num_cameras, H, W, C] 或 [history_len, num_cameras, H, W, C]（可选）
            min_ds: 最小降采样率（baseline），默认为1
            minimum_decay_steps: 最小衰减步数，默认为2
        
        Returns:
            结果字典，包含:
                - deviation: float EEF误差（ds=2 vs baseline）
        
        Raises:
            RuntimeError: 如果服务器返回错误
            ValueError: 如果响应格式不正确
        """
        # 发送请求
        request = {
            "type": "state_deviation",
            "s_env": s_env,
            "A": A,
            "min_ds": min_ds,
            "minimum_decay_steps": minimum_decay_steps
        }
        if rgb_obs is not None:
            request["rgb_obs"] = rgb_obs
        
        self._conn.send(self._packer.pack(request))
        
        # 接收响应
        response = self._conn.recv()
        if isinstance(response, str):
            # 字符串响应表示错误
            raise RuntimeError(f"Server error:\n{response}")
        
        result = unpackb(response)
        
        # 检查响应状态
        if result.get("status") == "error":
            raise ValueError(f"服务器返回错误: {result.get('message', 'Unknown error')}")
        
        return {
            "deviation_mean": float(result["deviation_mean"]),
            "deviation_max": float(result["deviation_max"]),
        }
    
    def check_health(self) -> Dict[str, Any]:
        """检查服务器健康状态（复用连接）
        
        Returns:
            健康检查响应字典
        """
        request = {"type": "health"}
        self._conn.send(self._packer.pack(request))
        
        response = self._conn.recv()
        if isinstance(response, str):
            raise RuntimeError(f"Server error:\n{response}")
        
        return unpackb(response)
    
    def close(self):
        """关闭连接"""
        if hasattr(self, '_conn') and self._conn:
            self._conn.close()
            logging.info(f"Connection to {self._uri} closed")
    
    def __enter__(self):
        """支持 with 语句"""
        return self
    
    def __exit__(self, exc_type, exc_val, exc_tb):
        """退出 with 语句时自动关闭连接"""
        self.close()
        return False

