"""
Safe-IQL Agent 客户端查询模板
用于从另一个 Python 环境查询 Safe-IQL Agent 服务器（使用 WebSocket）

注意：
- 客户端应发送未标准化的状态和动作序列
- 服务器端会自动处理标准化（如果配置了 obs_mean/obs_std 和 action_mean/action_std）
- 动作序列应该是未归一化的 delta actions（如果使用 delta 模式）
- 使用 WebSocket 可以避免与 HTTP proxy 的冲突
- 使用持久连接类 SafeIQLClient 可以避免频繁握手，提高推理速度
"""

import numpy as np
from typing import Optional, Dict, Any
import logging
import time

import websockets.sync.client
import functools
import msgpack


def pack_array(obj):
    """将 NumPy 数组打包为 msgpack 格式"""
    if (isinstance(obj, (np.ndarray, np.generic))) and obj.dtype.kind in ("V", "O", "c"):
        raise ValueError(f"Unsupported dtype: {obj.dtype}")

    if isinstance(obj, np.ndarray):
        return {
            b"__ndarray__": True,
            b"data": obj.tobytes(),
            b"dtype": obj.dtype.str,
            b"shape": obj.shape,
        }

    if isinstance(obj, np.generic):
        return {
            b"__npgeneric__": True,
            b"data": obj.item(),
            b"dtype": obj.dtype.str,
        }

    return obj


def unpack_array(obj):
    """从 msgpack 格式解包 NumPy 数组"""
    if b"__ndarray__" in obj:
        return np.ndarray(buffer=obj[b"data"], dtype=np.dtype(obj[b"dtype"]), shape=obj[b"shape"])

    if b"__npgeneric__" in obj:
        return np.dtype(obj[b"dtype"]).type(obj[b"data"])

    return obj


# 创建带有 NumPy 支持的 Packer 和 Unpacker
Packer = functools.partial(msgpack.Packer, default=pack_array)
packb = functools.partial(msgpack.packb, default=pack_array)

Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array)
unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array)


class SafeIQLClient:
    """Safe-IQL Agent WebSocket 客户端（持久连接）
    
    参照 openpi 的 WebsocketClientPolicy 实现。
    建立一次连接，可以多次查询，避免频繁握手。
    
    使用示例：
        # 创建客户端（自动连接）
        client = SafeIQLClient(host="127.0.0.1", port=8888)
        
        # 多次查询（复用连接）
        for i in range(100):
            predicted_k = client.query_predict_k(s_env, A)
        
        # 关闭连接
        client.close()
        
        # 或使用 with 语句自动管理连接
        with SafeIQLClient(host="127.0.0.1", port=8888) as client:
            for i in range(100):
                predicted_k = client.query_predict_k(s_env, A)
    """
    
    def __init__(self, host: str = "127.0.0.1", port: int = 8888, timeout: int = 10):
        """初始化客户端并建立连接
        
        Args:
            host: 服务器地址
            port: 端口号
            timeout: 连接超时时间（秒）
        """
        self.host = host
        self.port = port
        self.timeout = timeout
        
        # 构建 URI
        if host.startswith("ws://") or host.startswith("wss://"):
            self._uri = host
        else:
            self._uri = f"ws://{host}:{port}"
        
        # 创建 Packer 实例（复用）
        self._packer = Packer()
        
        # 连接到服务器
        self._conn, self._server_metadata = self._wait_for_server()
        
        logging.info(f"Connected to {self._uri}, metadata: {self._server_metadata}")
    
    def _wait_for_server(self):
        """等待服务器连接（带重试）"""
        logging.info(f"Connecting to server at {self._uri}...")
        while True:
            try:
                conn = websockets.sync.client.connect(
                    self._uri,
                    open_timeout=self.timeout,
                    compression=None,
                    max_size=None
                )
                # 接收服务器元数据
                metadata = unpackb(conn.recv())
                return conn, metadata
            except ConnectionRefusedError:
                logging.info("Server not ready, waiting...")
                time.sleep(5)
            except Exception as e:
                logging.error(f"Connection failed: {e}")
                raise ConnectionError(f"无法连接到服务器 {self._uri}: {e}")
    
    def get_server_metadata(self) -> Dict:
        """获取服务器元数据"""
        return self._server_metadata
    
    def query_predict_k(self, s_env: np.ndarray, A: np.ndarray) -> int:
        """查询预测的 k 值（复用连接）
        
        Args:
            s_env: 状态数组 [obs_dim]，未标准化的原始状态
            A: 动作序列 [T, action_dim]，未归一化的原始动作序列
        
        Returns:
            predicted_k: 预测的 k 值
        
        Raises:
            RuntimeError: 如果服务器返回错误
            ValueError: 如果响应格式不正确
        """
        # 发送预测请求
        request = {
            "type": "predict_k",
            "s_env": s_env,
            "A": A
        }
        self._conn.send(self._packer.pack(request))
        
        # 接收响应
        response = self._conn.recv()
        if isinstance(response, str):
            # 字符串响应表示错误
            raise RuntimeError(f"Server error:\n{response}")
        
        result = unpackb(response)
        
        # 检查响应状态
        if result.get("status") == "error":
            raise ValueError(f"服务器返回错误: {result.get('message', 'Unknown error')}")
        
        if result.get("status") != "success":
            raise ValueError(f"未知的响应状态: {result.get('status')}")
        
        if "predicted_k" not in result:
            raise ValueError("响应中缺少 'predicted_k' 字段")
        
        return int(result["predicted_k"])
    
    def check_health(self) -> Dict[str, Any]:
        """检查服务器健康状态（复用连接）
        
        Returns:
            健康检查响应字典
        """
        request = {"type": "health"}
        self._conn.send(self._packer.pack(request))
        
        response = self._conn.recv()
        if isinstance(response, str):
            raise RuntimeError(f"Server error:\n{response}")
        
        return unpackb(response)
    
    def close(self):
        """关闭连接"""
        if hasattr(self, '_conn') and self._conn:
            self._conn.close()
            logging.info(f"Connection to {self._uri} closed")
    
    def __enter__(self):
        """支持 with 语句"""
        return self
    
    def __exit__(self, exc_type, exc_val, exc_tb):
        """退出 with 语句时自动关闭连接"""
        self.close()
        return False


# ==================== 兼容性函数（一次性连接）====================
# 以下函数保留用于简单的一次性查询场景


def query_predict_k(s_env: np.ndarray, A: np.ndarray, host: str, port: int, timeout: int = 10) -> int:
    """
    向指定的主机和端口发起 predict_k 请求（一次性连接）
    
    注意：此函数每次调用都会建立新连接，适用于偶尔查询的场景。
    如果需要频繁查询（如 Libero 仿真），请使用 SafeIQLClient 类以复用连接。
    
    Args:
        s_env: 状态数组 [obs_dim]，未标准化的原始状态
        A: 动作序列 [T, action_dim]，未归一化的原始动作序列（delta actions）
        host: 服务器地址 (例如 '127.0.0.1' 或 'my_server_ip')
        port: 端口号 (例如 8888)
        timeout: 连接超时时间（秒），默认为 10
    
    Returns:
        predicted_k: 预测的 k 值
    
    Raises:
        ConnectionError: 如果无法连接到服务器
        ValueError: 如果响应格式不正确
    
    推荐用法（频繁查询）：
        client = SafeIQLClient(host, port)
        for i in range(100):
            predicted_k = client.query_predict_k(s_env, A)
        client.close()
    """
    # 使用临时客户端
    with SafeIQLClient(host, port, timeout) as client:
        return client.query_predict_k(s_env, A)


def check_health(host: str, port: int, timeout: int = 5) -> Dict[str, Any]:
    """
    检查服务器健康状态（一次性连接）
    
    注意：此函数每次调用都会建立新连接。
    如果需要频繁查询，请使用 SafeIQLClient 类。
    
    Args:
        host: 服务器地址
        port: 端口号
        timeout: 连接超时时间（秒），默认为 5
    
    Returns:
        健康检查响应字典，包含 "status" 和其他服务器元数据
    
    Raises:
        ConnectionError: 如果无法连接到服务器
    """
    # 使用临时客户端
    with SafeIQLClient(host, port, timeout) as client:
        return client.check_health()


if __name__ == "__main__":
    # 示例 1：使用持久连接（推荐用于频繁查询）
    print("=" * 60)
    print("示例 1: 使用持久连接（适用于频繁查询）")
    print("=" * 60)
    
    SERVER_HOST = "127.0.0.1"
    SERVER_PORT = 8888
    
    # 创建测试数据
    obs_dim = 7
    action_dim = 7
    T = 10
    
    try:
        # 创建客户端（建立连接）
        print(f"\n连接到服务器 {SERVER_HOST}:{SERVER_PORT}...")
        client = SafeIQLClient(SERVER_HOST, SERVER_PORT)
        
        # 获取服务器元数据
        metadata = client.get_server_metadata()
        print(f"服务器元数据: {metadata}")
        
        # 多次查询（复用连接）
        print(f"\n进行 10 次查询测试...")
        for i in range(10):
            s_env = np.random.randn(obs_dim).astype(np.float32)
            A = np.random.randn(T, action_dim).astype(np.float32)
            
            predicted_k = client.query_predict_k(s_env, A)
            print(f"  查询 {i+1}: predicted_k = {predicted_k}")
        
        # 健康检查
        health = client.check_health()
        print(f"\n健康检查: {health}")
        
        # 关闭连接
        client.close()
        print("\n✓ 测试完成！")
        
    except ConnectionError as e:
        print(f"\n✗ 连接错误: {e}")
        print("请确保服务器正在运行")
    except Exception as e:
        print(f"\n✗ 错误: {e}")
    
    # 示例 2：使用 with 语句自动管理连接
    print("\n" + "=" * 60)
    print("示例 2: 使用 with 语句（自动管理连接）")
    print("=" * 60)
    
    try:
        with SafeIQLClient(SERVER_HOST, SERVER_PORT) as client:
            for i in range(5):
                s_env = np.random.randn(obs_dim).astype(np.float32)
                A = np.random.randn(T, action_dim).astype(np.float32)
                
                predicted_k = client.query_predict_k(s_env, A)
                print(f"  查询 {i+1}: predicted_k = {predicted_k}")
        
        print("\n✓ with 语句测试完成（连接已自动关闭）")
        
    except ConnectionError as e:
        print(f"\n✗ 连接错误: {e}")
    except Exception as e:
        print(f"\n✗ 错误: {e}")
    
    # 示例 3：使用兼容性函数（一次性连接）
    print("\n" + "=" * 60)
    print("示例 3: 使用兼容性函数（不推荐频繁使用）")
    print("=" * 60)
    
    try:
        s_env = np.random.randn(obs_dim).astype(np.float32)
        A = np.random.randn(T, action_dim).astype(np.float32)
        
        # 每次调用都会建立新连接
        predicted_k = query_predict_k(s_env, A, SERVER_HOST, SERVER_PORT)
        print(f"预测的 k 值: {predicted_k}")
        
        health = check_health(SERVER_HOST, SERVER_PORT)
        print(f"服务器状态: {health.get('status')}")
        
        print("\n✓ 兼容性函数测试完成")
        
    except ConnectionError as e:
        print(f"\n✗ 连接错误: {e}")
    except Exception as e:
        print(f"\n✗ 错误: {e}")

