import asyncio
from collections import defaultdict
import time
import logging
import sys
import threading
import signal
import json
import random
import multiprocessing
import os

from typing import Any, List, Union, Awaitable, Optional, Dict

import tornado
import tornado.web
import tornado.websocket
import tornado.ioloop
import tornado.httputil

from tornado.websocket import WebSocketClosedError


WEB_CLIENT_RUN_LOOPS = {}
DATA_FRAMES_GROUP = {}


class AllowCORSMixin:
    def set_default_headers(self):
        self.set_header("Access-Control-Allow-Origin", "*")
        self.set_header("Access-Control-Allow-Headers", "x-requested-with")
        self.set_header("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
        self.set_header("Content-Type", "application/octet-stream")

    def options(self):
        self.set_status(204)
        self.finish()


class DataFrame:
    """Data frame defines the structure of transferred logging data from client. Esentials of a dataframe including `wall_time`, `global_step` and `data`"""

    def __init__(
        self,
        data: Any,
        wall_time: float,
        global_step: int,
        next: "DataFrame" = None,
        group: str = "default",
    ):
        self._wall_time = wall_time
        self._group = group
        self._global_step = global_step
        self._data = data
        self._size = sys.getsizeof(data)
        self.next_ = next

    @property
    def wall_time(self) -> float:
        return self._wall_time

    @property
    def group(self) -> str:
        return self._group

    @property
    def global_step(self) -> int:
        return self._global_step

    @property
    def size(self) -> int:
        return self._size

    @property
    def data(self) -> Any:
        return self._data

    @data.setter
    def data(self, data: Any):
        self._data = data
        self._size = sys.getsizeof(data)

    @staticmethod
    def from_string(message: str) -> "DataFrame":
        """Create DataFrame instance from a string.

        Raises:
            NotImplementedError: [description]
        """

        # parse string and save it
        message = json.loads(message)
        # logging.debug(f"received message: {message}.")
        return DataFrame(
            data=message[1],
            wall_time=message[-1],
            global_step=message[2],
            group=message[0],
        )


class DataFrames:
    """DataFrame buffer"""

    def __init__(self, max_capacity_mb: float = 500.0):
        self._max_capacity = max_capacity_mb
        self._frames = []
        self._wall_times = []

        # XXX: global step is actually the real index of data frame
        #   if the received DataFrame does not provide a valid global step
        #   then the global step will be replaced with a sorted wall_time list
        # self._global_steps = []

    def append(self, data_frame: DataFrame):
        """Append new dataframe.

        Args:
            data_frame (DataFrame): data frame.
        """

        self._enforce_max_capacity()
        if len(self._frames) > 0:
            self._frames[-1].next_ = data_frame
        # if data_frame.global_step is not None:
        #     self._global_steps.append(data_frame.global_step)
        self._wall_times.append(data_frame.wall_time)

    def __call__(self, global_step: int = None) -> DataFrame:
        """Return a data frame by global step.

        Args:
            global_step (int, optional): Global step. Defaults to None.

        Returns:
            DataFrame: A data frame instance
        """

        assert global_step <= len(self._frames), "Out of index error."
        return self._frames[global_step]

    def _enforce_max_capacity(self):
        """Sample random frames and clear their data to ensure we're under the max
        capacity size. Enable it for mu
        """

        bytes_to_mb = 1e-6
        start_frames_to_keep = 1
        end_frames_to_keep = 10
        sizes = [frame.size for frame in self._frames]

        while (
            len(self._frames) > start_frames_to_keep + end_frames_to_keep
            and sum(sizes) * bytes_to_mb > self._max_capacity
        ):
            # XXX: randint(1, ...), we skip the start frame because that is a "safe"
            #      frame we can always rely on being available.
            idx_to_delete = random.randint(
                1, len(self._frames) - 1 - end_frames_to_keep
            )
            self._frames[idx_to_delete - 1].next_ = self._frames[idx_to_delete].next_
            del sizes[idx_to_delete]
            del self._frames[idx_to_delete]
            del self._timestamps[idx_to_delete]


class BroadcastWebSocket(tornado.websocket.WebSocketHandler):
    """This websocket receives the logging state, save it to local storage and broadcasts it to
    all web clients that have open web sockets via the `ExperimentSocket` handler.
    """

    def initialize(self, max_capacity_mb: float):
        self._logger = logging.getLogger(self.__class__.__name__)
        self._max_capacity_mb = max_capacity_mb

    async def open(self, experiment_id: str) -> None:
        """Create a dataframe buffer for experiment tagged with `experiment_id`.

        Args:
            experiment_id (str): Experiment id.
        """

        self._logger.debug(f"Broadcast websocket opened for experiment={experiment_id}")
        self._data_frames_group: Dict[str, DataFrames] = {}
        # DataFrames(max_capacity_mb=self._max_capacity_mb)
        self._experiment_id = experiment_id

        DATA_FRAMES_GROUP[experiment_id] = self._data_frames_group
        WEB_CLIENT_RUN_LOOPS[experiment_id] = set()

    def on_close(self) -> None:
        """Handling closing event."""

        self._logger.debug(
            f"Broadcast websocket closed for experiment={self._experiment_id}"
        )

        del WEB_CLIENT_RUN_LOOPS[self._experiment_id]

        # FIXME: check whether there is any remained dataframe in the buffer.
        del DATA_FRAMES_GROUP[self._experiment_id]

    async def on_message(self, message: Union[str, bytes]) -> Optional[Awaitable[None]]:
        """Accepet new message (logging data) from experiment tagged with `experiment_id` and append
        it to the frame buffer. Furthermore, the coming dataframe will be stored in local storage.

        Args:
            message (Union[str, bytes]): Message from running experiment.

        Returns:
            Optional[Awaitable[None]]: None
        """

        data_frame = DataFrame.from_string(message)
        if self._data_frames_group.get(data_frame.group, None) is None:
            self._data_frames_group[data_frame.group] = DataFrames(
                max_capacity_mb=self._max_capacity_mb
            )
        self._data_frames_group[data_frame.group].append(data_frame=data_frame)


class WebClientRunLoop:
    """Web client run loop is responsible for a specific experiment logging monitor. Handling buffer flushing."""

    def __init__(
        self,
        data_frames_group: DataFrames,
        web_client_handler: tornado.websocket.WebSocketHandler,
        timestep_sec: float,
    ):
        self._logger = logging.getLogger(__class__.__name__)
        self._client = web_client_handler
        self._timestep_sec = timestep_sec
        self._data_frames_group = data_frames_group
        self._threads = {}

    def stop(self) -> None:
        """Stop event thread, timeout for 3 seconds."""

        if len(self._threads) > 0:
            for _thread in self._threads.values():
                _thread.join(timeout=3)
            self._threads = {}

    def run_forever(self):
        """Start threads for each data frame buffer."""

        def run_loop(frames_desc: str):
            frames = self._data_frames_group[frames_desc]
            frame_ptr = None

            while frame_ptr is None:
                time.sleep(0.5 * self._timestep_sec)
                frame_ptr = frames.start_frame
                frames_to_send = [frame_ptr]

            while True:
                assert len(frames_to_send) > 0, "No avaiable datas to send"
                closed = self._push_data_frames(frames_to_send)
                if closed:
                    self._logger.debug(
                        f"Socket closed for experiment={self._client.experiment_id}, exiting"
                    )
                    return

                frame_ptr, frames_to_send = self._wait_for_next_frame(frame_ptr)

        def sync_run_forever(frames_desc: str):
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

            loop.run_until_complete(run_loop(frames_desc))
            loop.close()

        self._threads = {
            desc: threading.Thread(target=sync_run_forever, args=(desc,), daemon=True)
            for desc in self._data_frames_group
        }
        _ = [_thread.start() for _thread in self._threads.values()]

    def _push_data_frames(self, data_frames: List[DataFrame]):
        try:
            cleaned_data = [
                {
                    "walltime": data_frame.wall_time,
                    "global_step": data_frame.global_step,
                    "content": data_frame.data,
                }
                for data_frame in data_frames
            ]
            # TODO: both json and tensorboard
            self._client.write_message(json.dumps(cleaned_data))
            return False
        except WebSocketClosedError:
            return True

    def _calculate_frame_delay(self, frame_ptr):
        # we may want to be more clever here in the future...
        return 0.5 * self._timestep_sec if not frame_ptr.next_ else 0

    def _wait_for_next_frame(self, frame_ptr):
        FRAME_BATCH_SIZE = 100  # limit the batch size for bandwidth and to allow breaks for seeks to be handled
        while True:
            delay = self._calculate_frame_delay(frame_ptr)
            time.sleep(delay)
            frames_to_send = []
            while frame_ptr.next_ and len(frames_to_send) <= FRAME_BATCH_SIZE:
                frame_ptr = frame_ptr.next_
                frames_to_send.append(frame_ptr)
            if len(frames_to_send) > 0:
                return frame_ptr, frames_to_send


class ExperimentStateWebSocket(tornado.websocket.WebSocketHandler):
    def initialize(self):
        self._logger = logging.getLogger(self.__class__.__name__)

    def check_origin(self, origin):
        return True

    async def open(self, experiment_id):
        if experiment_id not in WEB_CLIENT_RUN_LOOPS:
            raise tornado.web.HTTPError(404)

        # control the buffer flush frequency
        timestep_sec = 0.1

        self._run_loop = WebClientRunLoop(
            data_frames_group=DATA_FRAMES_GROUP[experiment_id],
            timestep_sec=timestep_sec,
            web_client_handler=self,
        )

        self._logger.debug(
            f"Experiment state websocket opened for experiment={experiment_id}"
        )
        WEB_CLIENT_RUN_LOOPS[experiment_id].add(self._run_loop)

        self._run_loop.run_forever()

    def on_close(self):
        self._logger.debug(f"Experiment state websocket closed")
        for run_loop in WEB_CLIENT_RUN_LOOPS.values():
            if self in run_loop:
                self._run_loop.stop()
                run_loop.remove(self._run_loop)


class ExperimentsHandler(AllowCORSMixin, tornado.web.RequestHandler):
    async def get(self):
        response = json.dumps({"experiments": list(WEB_CLIENT_RUN_LOOPS.keys())})
        self.write(response)


class MainHandler(tornado.web.RequestHandler):
    def __init__(
        self,
        application: tornado.web.Application,
        request: tornado.httputil.HTTPServerRequest,
        **kwargs: Any,
    ) -> None:
        super().__init__(application, request, **kwargs)
        self._logger = logging.getLogger(self.__class__.__name__)

    def get(self):
        # TODO: display in frontend
        self._logger.warning(
            "Main handler got a GET request without available web client is provided yet."
        )


def on_shutdown():
    logging.debug("Shutting down experiment monitor")
    tornado.ioloop.IOLoop.current().stop()


def make_app(max_capacity_mb: float):
    return tornado.web.Application(
        [
            (r"/", MainHandler),
            # list all available experiments
            (r"/experiments", ExperimentsHandler),
            # show state of a experiment tagged with `experiment_id`
            (r"/experiments/(?P<experiment_id>\w+)/state", ExperimentStateWebSocket),
            (
                r"/experiments/(?P<experiment_id>\w+)/broadcast",
                BroadcastWebSocket,
                {"max_capacity_mb": max_capacity_mb},
            ),
        ]
    )


def run(
    port: int = 8081,
    max_capacity_mb: float = 500.0,
    queue: multiprocessing.Queue = None,
):
    app = make_app(max_capacity_mb)
    app.listen(port)

    logging.debug(f"Start experiment monitor on port={port}")

    ioloop = tornado.ioloop.IOLoop.current()

    signal.signal(
        signal.SIGINT,
        lambda signal, _: ioloop.add_callback_from_signal(on_shutdown),
    )

    signal.signal(
        signal.SIGTERM, lambda signal, _: ioloop.add_callback_from_signal(on_shutdown)
    )

    ioloop.start()


class LoggingServerProcess:
    def __init__(self, port: int = 8081, max_capacity_mb: float = 500.0):
        # XXX: the default process start method has been changed to 'spawn' in macOS,
        #   since the fork method will lead to crashes of subprocesses. See: https://bugs.python.org/issue33725.
        #   Besides, the spwan method is is rather slow compared to using fork or forkserver.
        os_type = os.uname().sysname
        if "Linux" == os_type:
            _start_method = "fork"
        elif "Darwin" == os_type:
            _start_method = "spawn"
        else:
            raise TypeError(f"Unsupported os type: {os_type}")
        multiprocessing.set_start_method(_start_method)

        self._ip = "localhost"
        self._port = port
        self._queue = multiprocessing.Queue()
        self._proc = multiprocessing.Process(
            target=run, args=(port, max_capacity_mb, self._queue)
        )
        self._proc.daemon = True
        self._proc.start()

    @property
    def endpoint(self) -> str:
        """Return the websocket address."""

        return f"ws://{self._ip}:{self._port}"

    def send(self, msg: str):
        self._queue.put(msg)

    def close(self):
        self._proc.join(timeout=3)
        self._process = None
        self._queue.close()
        self._queue = None
