import asyncio
import logging
import threading
from typing import Callable, Any

import msgpack_numpy
import websockets

logger = logging.getLogger("async_ws_client")  # explicit name (not __name__); configure logging under it


class AsyncWebsocketClient:
    """Msgpack-over-websocket client driven by a private asyncio event loop.

    ``start()`` launches a background thread that connects to ``ws://host:port``
    and reconnects after ``reconnect_delay`` seconds whenever the connection
    fails or closes, until ``stop()`` is called.  Each received frame is fed to
    a ``msgpack_numpy`` unpacker and every decoded object is passed to the
    handler registered with ``set_message_handler``; the handler runs on the
    background thread.  ``send_message`` may be called from any thread.
    """

    def __init__(self, host: str = "localhost", port: int = 8765, *, reconnect_delay: float = 1.0):
        """Configure the client; no connection is made until ``start()``.

        Args:
            host: Server hostname.
            port: Server port.
            reconnect_delay: Seconds to wait before reconnecting after the
                connection fails or closes.
        """
        self.uri = f"ws://{host}:{port}"
        self.ws = None  # live connection; assigned only by main_loop on the loop thread
        self.packer = msgpack_numpy.Packer(use_bin_type=True)
        self.unpacker = msgpack_numpy.Unpacker()  # replaced on every (re)connect
        self.loop = None  # event loop created and owned by the background thread
        self.thread = None
        self.message_handler = None
        self.running = False  # cleared by stop(); main_loop exits once it is False
        self.reconnect_delay = reconnect_delay
        self._handler_lock = threading.Lock()  # guards message_handler across threads

    def set_message_handler(self, handler: Callable[[Any], None]):
        """Register ``handler`` to receive each decoded object (``None`` disables it).

        Safe to call from any thread, including from inside the current handler.
        """
        with self._handler_lock:
            self.message_handler = handler

    def _run_client(self):
        """Thread target: run ``main_loop`` on a fresh event loop, then tear it down."""
        loop = asyncio.new_event_loop()
        self.loop = loop
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(self.main_loop())
        finally:
            # Clear the flag even if main_loop crashed, so start() can relaunch.
            self.running = False
            try:
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                asyncio.set_event_loop(None)
                loop.close()

    async def main_loop(self):
        """Connect, receive until disconnected, and reconnect while ``running`` is true.

        ``start()`` sets ``running`` before launching the thread.  This coroutine
        deliberately does not set it: doing so could undo a ``stop()`` issued
        before the thread reached this point and leave ``stop()`` joining
        forever.  Callers that await ``main_loop`` directly must set
        ``running = True`` first.
        """
        while self.running:
            try:
                async with websockets.connect(self.uri, compression=None, max_size=2**24) as websocket:
                    self.ws = websocket
                    # Fresh decoder per connection: a partial object left over
                    # from a dropped connection must not prefix the new stream.
                    self.unpacker = msgpack_numpy.Unpacker()
                    await self.on_connected()
                    await self.receive_messages()
            except (websockets.exceptions.WebSocketException, OSError, asyncio.TimeoutError) as e:
                # WebSocketException also covers handshake failures (e.g. an HTTP
                # error status), which would otherwise escape and end the thread.
                logger.warning("Connection error: %s. Reconnecting in %gs...", e, self.reconnect_delay)
            else:
                if self.running:
                    logger.info("Connection to %s closed. Reconnecting in %gs...", self.uri, self.reconnect_delay)
            finally:
                self.ws = None
            if self.running:
                # Back off after a clean close too, so a server that accepts and
                # immediately closes is not hammered with reconnects.
                await asyncio.sleep(self.reconnect_delay)

    async def on_connected(self):
        """Hook run after each successful connection; logs the URI by default."""
        logger.info("Connected to %s", self.uri)

    async def receive_messages(self):
        """Receive frames from ``self.ws`` until it closes or ``running`` is cleared."""
        while self.ws and self.running:
            try:
                message = await self.ws.recv()
            except websockets.exceptions.ConnectionClosed:
                break
            except Exception as e:
                # Any other recv() failure is treated as fatal for this connection;
                # retrying would likely fail again at once, so return and let
                # main_loop reconnect.
                logger.error("Error receiving message: %s", e)
                break
            self._process_frame(message)

    def _process_frame(self, message):
        """Feed one frame to the unpacker and dispatch every complete decoded object."""
        try:
            self.unpacker.feed(message)
            for obj in self.unpacker:
                self._dispatch(obj)
        except Exception as e:
            logger.error("Error receiving message: %s", e)
            # Buffered bytes may be unusable after a decode error; start clean so
            # one bad frame does not poison the following ones.
            self.unpacker = msgpack_numpy.Unpacker()

    def _dispatch(self, obj):
        """Pass ``obj`` to the registered handler, logging (not propagating) its errors."""
        # Read the handler under the lock but call it outside, so a handler that
        # calls set_message_handler() cannot deadlock on the non-reentrant lock.
        with self._handler_lock:
            handler = self.message_handler
        if handler is None:
            return
        try:
            handler(obj)
        except Exception:
            logger.exception("Message handler raised an exception")

    async def send(self, data):
        """Pack ``data`` with msgpack_numpy and send it; must run on the client loop.

        Logs a warning instead of sending when not connected or when the
        connection is closed.  Packing errors propagate to the caller.
        """
        ws = self.ws
        if ws is None:
            logger.warning("Client not connected, cannot send message.")
            return
        try:
            await ws.send(self.packer.pack(data))
        except websockets.exceptions.ConnectionClosed:
            logger.warning("Tried to send message, but connection is closed.")

    def start(self):
        """Launch the background client thread; no-op if already running."""
        if self.running:
            return
        self.running = True
        if self.thread is None or not self.thread.is_alive():
            self.thread = threading.Thread(target=self._run_client)
            self.thread.start()

    def stop(self):
        """Stop reconnecting, close the live connection and wait for the thread.

        When called from the client's own thread (e.g. inside the message
        handler) the join is skipped, since a thread cannot join itself.
        """
        self.running = False
        # Snapshot: the loop thread may reset self.ws at any moment.
        ws, loop = self.ws, self.loop
        if ws is not None and loop is not None and loop.is_running():
            close_coro = ws.close()
            try:
                asyncio.run_coroutine_threadsafe(close_coro, loop)
            except RuntimeError:
                # The loop closed after the is_running() check; nothing to close.
                close_coro.close()
        thread = self.thread
        if thread is not None and thread.is_alive() and thread is not threading.current_thread():
            thread.join()

    def send_message(self, data):
        """Thread-safe: schedule ``send(data)`` on the client's event loop.

        Returns:
            The ``concurrent.futures.Future`` of the scheduled send (failures,
            e.g. data the packer cannot serialize, are also logged), or ``None``
            if the loop is not running and nothing was scheduled.
        """
        loop = self.loop
        if loop is None or not loop.is_running():
            logger.warning("Event loop is not running. Cannot send message.")
            return None
        coro = self.send(data)
        try:
            future = asyncio.run_coroutine_threadsafe(coro, loop)
        except RuntimeError:
            # The loop was closed between the check above and scheduling.
            coro.close()
            logger.warning("Event loop is not running. Cannot send message.")
            return None
        future.add_done_callback(self._report_send_failure)
        return future

    @staticmethod
    def _report_send_failure(future):
        """Done-callback: log an exception raised by a scheduled ``send``."""
        if future.cancelled():
            return
        exc = future.exception()
        if exc is not None:
            logger.error("Failed to send message: %s", exc)
