import asyncio
import http
import logging
import time
import traceback

from openpi_client import base_policy as _base_policy
from openpi_client import msgpack_numpy
import websockets.asyncio.server as _server
import websockets.frames

# Module-level logger, named after this module so its output can be filtered per module.
logger = logging.getLogger(__name__)


class WebsocketPolicyServer:
    """Serves a policy using the websocket protocol. See websocket_client_policy.py for a client implementation.

    On connect, the server sends a single msgpack-encoded metadata frame. It then loops:
    receive an observation, call ``policy.infer`` on it, and send back the resulting
    action with a ``server_timing`` entry attached.

    Args:
        policy: Policy whose ``infer`` method is called once per request.
        host: Interface to bind to.
        port: Port to listen on; passed to ``websockets`` unchanged.
        metadata: Extra key/value pairs sent to every client in the first frame.
        compare_and_select: When true, each request gets ``compare_and_select=True``
            unless the client already set that key.
        default_num_candidates: When set (and ``compare_and_select`` is true), used as
            ``num_candidates`` for requests that do not specify one. It is also
            advertised in the metadata frame.
    """

    def __init__(
        self,
        policy: _base_policy.BasePolicy,
        host: str = "0.0.0.0",
        port: int | None = None,
        metadata: dict | None = None,
        *,
        compare_and_select: bool = False,
        default_num_candidates: int | None = None,
    ) -> None:
        self._policy = policy
        self._host = host
        self._port = port
        self._metadata = metadata or {}
        self._compare_and_select = compare_and_select
        self._default_num_candidates = default_num_candidates
        logging.getLogger("websockets.server").setLevel(logging.INFO)

    def serve_forever(self) -> None:
        """Run the server on a fresh event loop; blocks until the loop exits."""
        asyncio.run(self.run())

    async def run(self) -> None:
        """Start the websocket server and serve connections until cancelled."""
        async with _server.serve(
            self._handler,
            self._host,
            self._port,
            compression=None,
            max_size=None,
            process_request=_health_check,
        ) as server:
            await server.serve_forever()

    def _connection_metadata(self) -> dict:
        """Return the metadata frame sent to a newly connected client."""
        # Copy so per-connection additions never leak into the shared dict.
        metadata = dict(self._metadata)
        # setdefault: values supplied by the caller in `metadata` take precedence.
        metadata.setdefault("compare_and_select_enabled", bool(self._compare_and_select))
        if self._default_num_candidates is not None:
            metadata.setdefault("default_num_candidates", int(self._default_num_candidates))
        return metadata

    def _apply_request_defaults(self, obs: dict) -> None:
        """Fill in server-level compare-and-select defaults on ``obs`` in place."""
        if not self._compare_and_select:
            return
        # Client-provided values always win over server defaults.
        obs.setdefault("compare_and_select", True)
        if self._default_num_candidates is not None:
            obs.setdefault("num_candidates", self._default_num_candidates)

    async def _handler(self, websocket: _server.ServerConnection) -> None:
        """Serve one client connection until it closes or a request fails.

        Raises:
            Exception: Any error other than a closed connection is re-raised after the
                traceback has been sent to the client and the socket closed with
                ``INTERNAL_ERROR``.
        """
        logger.info("Connection from %s opened", websocket.remote_address)
        packer = msgpack_numpy.Packer()

        await websocket.send(packer.pack(self._connection_metadata()))

        prev_total_time = None
        while True:
            try:
                # Taken before recv(), so the total also covers waiting for the request.
                start_time = time.monotonic()
                obs = msgpack_numpy.unpackb(await websocket.recv())
                self._apply_request_defaults(obs)

                # NOTE(review): infer() is a synchronous call inside a coroutine, so the
                # event loop does not service other connections while it runs.
                infer_start = time.monotonic()
                action = self._policy.infer(obs)
                infer_time = time.monotonic() - infer_start

                action["server_timing"] = {
                    "infer_ms": infer_time * 1000,
                }
                if prev_total_time is not None:
                    # We can only record the last total time since we also want to include the send time.
                    action["server_timing"]["prev_total_ms"] = prev_total_time * 1000

                await websocket.send(packer.pack(action))
                prev_total_time = time.monotonic() - start_time

            except websockets.ConnectionClosed:
                logger.info("Connection from %s closed", websocket.remote_address)
                break
            except Exception:
                # Give the client the traceback as a text frame before closing.
                await websocket.send(traceback.format_exc())
                await websocket.close(
                    code=websockets.frames.CloseCode.INTERNAL_ERROR,
                    reason="Internal server error. Traceback included in previous frame.",
                )
                raise


def _health_check(connection: _server.ServerConnection, request: _server.Request) -> _server.Response | None:
    """Answer ``/healthz`` with a plain HTTP 200; let every other path proceed.

    Returning ``None`` hands the request back for normal handling.
    """
    if request.path != "/healthz":
        return None
    return connection.respond(http.HTTPStatus.OK, "OK\n")
