from datetime import datetime
import json
import logging
import warnings
import time
import multiprocessing

from typing import Any, Union, Dict

import websocket
import numpy as np

from expground.utils.helper import unpack


class JSONEncoder(json.JSONEncoder):
    """JSON encoder that also serializes numpy values and non-finite floats.

    The JSON spec has no standardized representation for NaN or +/-Infinity,
    so when such values reach ``default`` they are emitted as the strings
    ``"NaN"``, ``"Infinity"`` and ``"-Infinity"``. numpy arrays, booleans,
    integers and floats are converted to their native Python equivalents.

    Note that ``json`` only calls ``default`` for objects it cannot encode by
    itself: plain Python floats and lists are normally encoded directly (with
    ``allow_nan`` deciding how NaN/Infinity are written). The float and list
    branches therefore apply to values produced by numpy conversions and to
    direct calls of ``default``.
    """

    # Values that may need converting when found inside a list; everything
    # else is handed back to ``json`` untouched.
    _CONVERTIBLE = (float, list, np.bool_, np.integer, np.floating, np.ndarray)

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``.

        Raises:
            TypeError: If ``obj`` is of a type this encoder does not support.
        """
        if isinstance(obj, (float, np.floating)):
            if np.isposinf(obj):
                return "Infinity"
            if np.isneginf(obj):
                return "-Infinity"
            if np.isnan(obj):
                return "NaN"
            return float(obj)
        if isinstance(obj, list):
            return [self._convert_item(item) for item in obj]
        if isinstance(obj, np.bool_):
            # Return the Python bool so it is encoded as JSON true/false,
            # not as the string "true"/"false".
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.ndarray):
            # tolist() yields nested lists, or a bare scalar for 0-d arrays.
            return self._convert_item(obj.tolist())

        return super().default(obj)

    def _convert_item(self, item):
        """Convert ``item`` if it needs it; otherwise pass it through.

        Natively encodable values (ints, bools, strings, None, ...) are
        returned as-is. Any other unsupported object is also returned as-is,
        so ``json`` calls ``default`` on it again and raises ``TypeError``.
        """
        if isinstance(item, self._CONVERTIBLE):
            return self.default(item)
        return item


class Client:
    """Forward experiment states to a websocket logging server.

    Construction starts a daemon child process that connects to
    ``{endpoint}/experiments/{client_id}/broadcast`` and sends every state put
    on ``state_queue``. Non-string states are passed through ``unpack`` and
    serialized with ``JSONEncoder``; strings are sent unchanged.
    """

    class QueueDone:
        """Sentinel put on a queue to tell its consumer to stop."""

    def __init__(
        self,
        endpoint: str = None,
        wait_between_retries: float = 0.5,
        experiment_name: str = None,
    ):
        """Start the websocket worker process.

        Args:
            endpoint (str, optional): Base websocket URL of the logging server.
                Defaults to ``"ws://localhost:8081"``.
            wait_between_retries (float, optional): Seconds to wait between
                connection attempts until the first connection succeeds.
                Defaults to 0.5.
            experiment_name (str, optional): Prefix of the client id; falls
                back to ``"experiment_"``. Defaults to None.
        """
        self._logger = logging.getLogger(self.__class__.__name__)
        if endpoint is None:
            endpoint = "ws://localhost:8081"

        self._logging_process = None
        self._logging_queue = None
        self._wait_between_retries = wait_between_retries
        self._experiment_name = experiment_name

        experiment_name = experiment_name or "experiment_"
        time_stamp = datetime.now().strftime("%Y%m%d%H%M%S%f")[:-4]
        client_id = f"{experiment_name}_{time_stamp}"

        self._state_queue = multiprocessing.Queue()
        self._process = multiprocessing.Process(
            target=self._connect,
            args=(
                f"{endpoint}/experiments/{client_id}/broadcast",
                self._state_queue,
                # Hand the configured retry interval to the worker; without it
                # `_connect` would silently fall back to its own default.
                self._wait_between_retries,
            ),
        )
        self._process.daemon = True
        self._process.start()
        self._logger.debug(f"Experiment client for `{experiment_name}` is running ...")

    @property
    def state_queue(self):
        """multiprocessing.Queue: Queue drained by the worker (None after teardown)."""
        return self._state_queue

    @staticmethod
    def _write_log_state(queue, path):
        """Write states from ``queue`` to ``path``, one per line, until ``QueueDone``.

        Args:
            queue: Queue of states; non-string states are unpacked and
                JSON-encoded, strings are written as-is.
            path: Output file (must provide ``.open``, e.g. ``pathlib.Path``);
                it is truncated and written as UTF-8.
        """
        with path.open("w", encoding="utf-8") as f:
            while True:
                state = queue.get()
                if isinstance(state, Client.QueueDone):
                    break

                if not isinstance(state, str):
                    state = json.dumps(unpack(state), cls=JSONEncoder)

                f.write(f"{state}\n")

    @staticmethod
    def read_and_send(
        path: str,
        endpoint: str = "ws://localhost:8081",
        timestep_sec: float = 0.1,
        wait_between_retries: float = 0.5,
    ):
        """Replay a file of pre-serialized states, one line per state.

        Args:
            path (str): UTF-8 file with one serialized state per line.
            endpoint (str, optional): Base websocket URL of the logging server.
            timestep_sec (float, optional): Delay before sending each line.
            wait_between_retries (float, optional): Retry interval for the
                initial connection.
        """
        client = Client(
            endpoint=endpoint,
            wait_between_retries=wait_between_retries,
        )
        try:
            # Same encoding that `_write_log_state` uses when producing the file.
            with open(path, "r", encoding="utf-8") as f:
                for line in f:
                    time.sleep(timestep_sec)
                    client._send_raw(line.rstrip("\n"))
        finally:
            # Always stop the worker, even if reading the file fails.
            client.teardown()
        logging.info("Finished Envision data replay")

    def _connect(
        self,
        endpoint: str,
        state_queue: multiprocessing.Queue,
        wait_between_retries: float = 0.05,
    ):
        """Worker entry point: stream states from ``state_queue`` to ``endpoint``.

        Runs in the child process. Once connected, every state taken from the
        queue is sent until a ``QueueDone`` sentinel arrives, which closes the
        socket. Before the first connection, attempts are spaced by
        ``wait_between_retries`` seconds; after a dropped connection they are
        spaced by 3 seconds. Returns when a previously established connection
        ends while the queue is empty.
        """
        connection_established = False

        def optionally_serialize_and_write(state: Union[Any, str], ws):
            # Strings are treated as already serialized (see `_send_raw`).
            if not isinstance(state, str):
                state = json.dumps(unpack(state), cls=JSONEncoder)
            ws.send(state)

        def on_close(ws, *args, **kwargs):
            self._logger.debug("Connection to logging server closed")

        def on_error(ws, error):
            # NOTE(review): presumably raised inside websocket-client when no
            # socket exists yet; ignored as noise.
            if str(error) == "'NoneType' object has no attribute 'sock'":
                return

            if connection_established:
                self._logger.error(
                    f"Connection to logging server terminated with: {error}"
                )
            else:
                self._logger.info(f"Connection to logging server failed with: {error}.")

        def on_open(ws):
            nonlocal connection_established
            connection_established = True
            self._logger.debug("Connection established!")

            # Block inside the callback, forwarding states until told to stop.
            while True:
                state = state_queue.get()
                if isinstance(state, Client.QueueDone):
                    ws.close()
                    break

                optionally_serialize_and_write(state, ws)

        def run_socket(endpoint, wait_between_retries):
            tries = 1
            while True:
                # TODO: use a real network socket instead (probably UDP)
                self._logger.debug(f"Try to launch websocket for trail={tries}")
                ws = websocket.WebSocketApp(
                    endpoint, on_error=on_error, on_close=on_close, on_open=on_open
                )

                with warnings.catch_warnings():
                    # XXX: websocket-client library seems to have leaks on connection
                    #      retry that cause annoying warnings within Python 3.8+
                    warnings.filterwarnings("ignore", category=ResourceWarning)
                    ws.run_forever()

                if not connection_established:
                    self._logger.info(f"Attempt {tries} to connect to logging server.")
                else:
                    # No information left to send, connection is likely done
                    if state_queue.empty():
                        break
                    # When connection lost, retry again every 3 seconds
                    wait_between_retries = 3
                    self._logger.info(
                        f"Connection to logging server lost. Attempt {tries} to reconnect."
                    )

                tries += 1
                time.sleep(wait_between_retries)

        run_socket(endpoint, wait_between_retries)

    def send(self, state: Any):
        """Queue ``state`` for the websocket worker and the file logger, if running.

        Safe to call after ``teardown``; the state is then dropped.
        """
        if self._process is not None and self._process.is_alive():
            self._state_queue.put(state)
        if self._logging_process is not None and self._logging_queue is not None:
            self._logging_queue.put(state)

    def _send_raw(self, state: str):
        """Skip serialization if we already have serialized data. This is useful if
        we are reading from file and forwarding through the websocket.

        Dropped silently after ``teardown``.
        """
        if self._state_queue is not None:
            self._state_queue.put(state)

    def teardown(self):
        """Stop the worker processes and release their queues. Idempotent."""
        if self._state_queue is not None:
            self._shutdown_worker(self._process, self._state_queue)
            self._process = None
            self._state_queue = None

        if self._logging_process is not None and self._logging_queue is not None:
            self._shutdown_worker(self._logging_process, self._logging_queue)
            self._logging_process = None
            self._logging_queue = None

    @staticmethod
    def _shutdown_worker(process, work_queue, timeout: float = 3):
        """Ask a worker to stop with ``QueueDone``, wait for it, then close its queue.

        The websocket worker only consumes the sentinel once connected, so a
        worker still retrying the connection would outlive the join timeout;
        it is terminated in that case. Because nobody will read the queue any
        more, its feeder thread is told not to block interpreter exit.
        """
        work_queue.put(Client.QueueDone())
        if process is not None:
            process.join(timeout=timeout)
            if process.is_alive():
                process.terminate()
                process.join(timeout=timeout)
                work_queue.cancel_join_thread()
        work_queue.close()


# TODO(): Keep one client per experiment so that users can look up a client
#   by name and get back an interface for logging.
class ClientSet:
    """Registry of experiment clients keyed by experiment name."""

    class Handler:
        """Logging front-end that enqueues messages on an experiment client's queue."""

        def __init__(self, experiment_name: str, queue: multiprocessing.Queue):
            """Create a handler instance.

            Args:
                experiment_name (str): Experiment name.
                queue (multiprocessing.Queue): State queue of an experiment client.
            """

            self._queue = queue
            self._experiment_name = experiment_name
            # `logging.getLogger` caches every logger for the life of the
            # process, so use a stable per-experiment name; a unique name per
            # instance would leak one logger per handler.
            self._logger = logging.getLogger(
                f"{self.__class__.__name__}.{experiment_name}"
            )

        def send(self, msg: Any, global_step: int = None) -> None:
            """Send message to client in a format of [name, message, global_step, wall_time]

            Args:
                msg (Any): Message.
                global_step (int, optional): Global step. Defaults to None.
            """
            self._queue.put([self._experiment_name, msg, global_step, time.time()])

    def __init__(self):
        self._clients: Dict[str, Client] = dict()

    def register(
        self,
        experiment_name: str,
        endpoint: str = None,
        wait_between_retries: float = 0.5,
    ) -> None:
        """Register experiment client with given configuration.

        Args:
            experiment_name (str): Experiment name; must not already be registered.
            endpoint (str, optional): Logger server address. Defaults to None.
            wait_between_retries (float, optional): Seconds to wait between
                connection attempts. Defaults to 0.5.
        """

        assert (
            experiment_name not in self._clients
        ), f"Experiment `{experiment_name}` is already registered"
        self._clients[experiment_name] = Client(
            endpoint=endpoint,
            wait_between_retries=wait_between_retries,
            experiment_name=experiment_name,
        )

    def get_handler(self, experiment_name: str, sub_name: str = None) -> "Handler":
        """Return a logging handler which coordinates with experiment client.

        Args:
            experiment_name (str): Name of a registered experiment.
            sub_name (str, optional): Currently unused; accepted for
                forward compatibility.

        Returns:
            Handler: A handler for experiment client named as `experiment_name`
        """
        assert (
            experiment_name in self._clients
        ), f"Experiment `{experiment_name}` is not registered"
        return ClientSet.Handler(
            experiment_name, self._clients[experiment_name].state_queue
        )

    def close_all(self):
        """Tear down every registered client and forget them."""
        for client in self._clients.values():
            client.teardown()
        self._clients.clear()

    def close(self, name: str):
        """Tear down the client registered as ``name`` and remove it."""
        self._clients[name].teardown()
        self._clients.pop(name)


# Module-level ClientSet instance: register experiments on it and obtain
# logging handlers through `get_handler`.
logging_client_set: ClientSet = ClientSet()
