import asyncio
import logging
import socket
import traceback
from typing import Set

import msgpack_numpy
import websockets

# Module logger shared by every AsyncWebsocketServer instance.
logger = logging.getLogger(name="async_ws_server")


class AsyncWebsocketServer:
    """Base class for an asyncio WebSocket server that exchanges msgpack data.

    Incoming messages are decoded with ``msgpack_numpy`` and each decoded
    object is passed to :meth:`on_message_received`. Outgoing objects are
    encoded the same way by :meth:`send`, :meth:`broadcast` and
    :meth:`send_error`.

    Subclasses must implement :meth:`on_message_received`. They may override
    :meth:`on_client_connected` and :meth:`server_task`.

    Start the server with :meth:`serve_forever` (blocking, owns its own event
    loop) or by awaiting :meth:`run`. :meth:`stop` schedules shutdown through
    ``call_soon_threadsafe``, so it can be called from another thread.
    """

    def __init__(self, host: str = "0.0.0.0", port: int = 8765):
        """Store configuration; nothing is bound until :meth:`run` executes.

        Args:
            host: IPv4 address to bind (the socket in ``run`` is AF_INET).
            port: TCP port to listen on.
        """
        self._host = host
        self._port = port
        # Currently connected clients, maintained by _register/_unregister.
        self._clients: Set[websockets.WebSocketServerProtocol] = set()
        # pack() is synchronous and returns the bytes for one object, so a
        # single Packer can safely serve every connection.
        self._packer = msgpack_numpy.Packer()
        # Retained for subclasses that may reference it. _handler decodes with
        # a per-connection Unpacker instead, because a streaming Unpacker
        # holds buffered state that must not be shared between clients.
        self._unpacker = msgpack_numpy.Unpacker()
        self._is_running = False
        self.loop = None
        self.server = None
        self.server_task_instance = None

    async def _register(self, websocket: websockets.WebSocketServerProtocol):
        """Add a newly connected client to the client set."""
        self._clients.add(websocket)
        logger.info(f"Client connected: {websocket.remote_address}. Total clients: {len(self._clients)}")

    async def _unregister(self, websocket: websockets.WebSocketServerProtocol):
        """Remove a client from the client set (no error if already absent)."""
        self._clients.discard(websocket)
        logger.info(f"Client disconnected: {websocket.remote_address}. Total clients: {len(self._clients)}")

    async def _handler(self, websocket: websockets.WebSocketServerProtocol):
        """Serve one connection until it closes.

        Every message is fed to this connection's decoder, and each complete
        object is dispatched to :meth:`on_message_received`. If decoding or
        handling fails, the error is logged and a traceback is sent back with
        :meth:`send_error`; the connection stays open.
        """
        await self._register(websocket)
        # One decoder per connection. msgpack's streaming Unpacker keeps
        # partial input between feed() calls. A shared instance would let one
        # client's partial or malformed data break decoding for everyone.
        unpacker = msgpack_numpy.Unpacker()
        try:
            await self.on_client_connected(websocket)
            async for message in websocket:
                try:
                    unpacker.feed(message)
                    for unpacked in unpacker:
                        await self.on_message_received(websocket, unpacked)
                except Exception as e:
                    logger.exception(f"Error processing message: {e}")
                    # Drop whatever remains of the failed message so the next
                    # message on this connection decodes from a clean state.
                    unpacker = msgpack_numpy.Unpacker()
                    # NOTE(review): this sends the full server traceback to the
                    # client; consider trimming it if clients are untrusted.
                    await self.send_error(websocket, traceback.format_exc())
        except websockets.exceptions.ConnectionClosed as e:
            logger.info(f"Connection closed with {websocket.remote_address}: {e}")
        finally:
            await self._unregister(websocket)

    async def on_client_connected(self, websocket: websockets.WebSocketServerProtocol):
        """Hook called once per connection before any message is read."""
        pass

    async def on_message_received(self, websocket: websockets.WebSocketServerProtocol, message):
        """Handle one decoded object from ``websocket``; subclasses must override."""
        raise NotImplementedError

    async def broadcast(self, message):
        """Encode ``message`` once and send it to every connected client.

        Sends run concurrently. A failure on one client (for example, a
        connection that has just closed) is logged and does not affect the
        other clients.
        """
        if not self._clients:
            return
        packed_message = self._packer.pack(message)
        # Snapshot the set: clients may (un)register while the sends are awaited.
        recipients = list(self._clients)
        # gather() accepts coroutines (asyncio.wait() rejects them on 3.11+).
        # return_exceptions=True collects failures instead of leaving them in
        # unretrieved tasks.
        results = await asyncio.gather(
            *(client.send(packed_message) for client in recipients),
            return_exceptions=True,
        )
        for client, result in zip(recipients, results):
            if isinstance(result, Exception):
                logger.warning(f"Broadcast to {client.remote_address} failed: {result}")

    async def send(self, websocket: websockets.WebSocketServerProtocol, message):
        """Encode ``message`` with msgpack_numpy and send it to one client."""
        await websocket.send(self._packer.pack(message))

    async def send_error(self, websocket: websockets.WebSocketServerProtocol, error_message: str):
        """Send ``{"error": error_message}`` (msgpack-encoded) to one client."""
        error_payload = {"error": error_message}
        await websocket.send(self._packer.pack(error_payload))

    async def run(self):
        """Bind, serve until the server is closed, then clean up.

        Startup failures are logged rather than raised. On every exit path
        the method does the following:

        * the socket is closed if the server never took ownership of it;
        * ``server_task`` is cancelled;
        * ``_is_running`` is reset.
        """
        logger.info(f"Attempting to start server on {self._host}:{self._port}...")
        sock = None
        periodic = None
        try:
            # Create the socket manually so SO_REUSEADDR can be set.
            # NOTE: AF_INET only, so IPv6 host addresses are not supported here.
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind((self._host, self._port))

            self.server = await websockets.serve(
                self._handler,
                sock=sock,
                compression=None,
                max_size=2**24,  # 16 MiB per message
            )
            # The server now owns the socket and closes it on shutdown.
            sock = None
            self._is_running = True
            logger.info(f"Server successfully started and listening on {self._host}:{self._port}")

            # Run server_task alongside the server until the server closes.
            periodic = asyncio.create_task(self.server_task())
            self.server_task_instance = periodic
            await self.server.wait_closed()

        except Exception as e:
            logger.error(f"Server run failed: {e}", exc_info=True)
        finally:
            self._is_running = False
            if sock is not None:
                # bind() or serve() failed before the server took ownership.
                sock.close()
            if periodic is not None:
                periodic.cancel()
                try:
                    await periodic
                except asyncio.CancelledError:
                    pass  # Task cancellation is expected
                except Exception as e:
                    # server_task finished with an error before it was cancelled.
                    logger.error(f"Server task failed: {e}", exc_info=True)
            logger.info(f"Server on {self._host}:{self._port} has stopped.")

    async def server_task(self):
        """Optional coroutine that runs while the server is listening.

        Subclasses can override this, for example to run a periodic job. It is
        started after the server begins listening and is cancelled when the
        server closes.
        """
        pass

    def serve_forever(self):
        """Run :meth:`run` on a new event loop, blocking until the server stops."""
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        try:
            self.loop.run_until_complete(self.run())
        finally:
            # Do not leave a closed loop installed as this thread's current loop.
            asyncio.set_event_loop(None)
            self.loop.close()
            self.loop = None

    def stop(self):
        """Request shutdown of a server started by :meth:`serve_forever`.

        This is safe to call from another thread. It does nothing unless both
        ``self.loop`` (set only by ``serve_forever``) and ``self.server`` are
        set, or once the loop has already been closed.
        """
        # Read each attribute once; serve_forever may clear self.loop concurrently.
        loop, server = self.loop, self.server
        if server and loop:
            try:
                loop.call_soon_threadsafe(server.close)
            except RuntimeError:
                pass  # Loop already closed: the server has stopped.