# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC
from dataclasses import dataclass
from queue import Empty as QueueEmpty
from queue import Queue
from typing import Dict, List, Optional

import torch

from fairscale.internal.object import pyobject_to_tensor, tensor_to_pyobject
from fairscale.nn.model_parallel import get_pipeline_parallel_group

from .types import MESSAGE_GENERATION_START, InputDevice, PipeMessage, Tensors

# Size passed to pyobject_to_tensor when SendRecvTransport serializes a message
# header, and the length of the uint8 buffer it allocates to receive one.
MESSAGE_TENSOR_SIZE = 1024

# One in-process queue per index in range(MESSAGE_GENERATION_START), selected by
# PipeMessage.queue_name. rpc_push_queue appends to these queues and RpcTransport
# reads from them.
MessageQueues: List[Queue] = [Queue() for _ in range(MESSAGE_GENERATION_START)]


def to_input_device(tensors: Tensors, input_device: InputDevice) -> Tensors:
    """Return ``tensors`` placed on ``input_device``.

    When ``input_device`` is ``None`` the argument is handed back untouched (the
    very same object). Otherwise a new tuple is built by calling
    ``.to(input_device)`` on every element.
    """
    if input_device is not None:
        return tuple(tensor.to(input_device) for tensor in tensors)
    return tensors


def rpc_push_queue(message: PipeMessage) -> None:
    """Append ``message`` to the module-level queue chosen by ``message.queue_name``.

    ``RpcTransport.send_message`` passes this function as the callable to run on
    the destination worker.
    """
    target_queue = MessageQueues[message.queue_name]
    target_queue.put(message)


@dataclass(frozen=True)
class Transport(ABC):
    """Base interface shared by the message transports in this module.

    Apart from ``recv_message``, the methods here are placeholders that do nothing
    and return ``None``; ``RpcTransport`` and ``SendRecvTransport`` override them.
    """

    # Looked up by ``message.dest`` in RpcTransport to obtain the destination
    # worker name; may be None for transports that do not need it.
    worker_map: Optional[Dict[int, str]]
    # Device that received tensors are placed on (None leaves the choice to the
    # concrete transport / torch defaults).
    input_device: InputDevice

    def recv_message(self, queue_name: int, *, nowait: bool = False) -> PipeMessage:
        """Receive a complete message: first the header, then its tensors."""
        return self.recv_message_tensors(self.recv_message_header(queue_name, nowait))

    def recv_message_header(self, queue_name: int, nowait: bool = False) -> PipeMessage:
        """Receive the header of the next message for ``queue_name`` (placeholder)."""

    def recv_message_tensors(self, message: PipeMessage) -> PipeMessage:
        """Receive the tensors belonging to ``message`` (placeholder)."""

    def send_message(self, message: PipeMessage, sync: bool = False, skip_header: bool = False) -> None:
        """Send ``message`` to its destination (placeholder)."""

    def get_out_of_order(self, queue_name: int, index: int) -> Tensors:
        """Return the tensors of the message whose ``args`` equals ``index`` (placeholder)."""


def MakeTransport(use_rpc: bool, worker_map: Optional[Dict[int, str]], input_device: InputDevice) -> Transport:
    """Construct the transport selected by ``use_rpc``.

    Returns a ``SendRecvTransport`` when ``use_rpc`` is false and an
    ``RpcTransport`` otherwise.

    Raises:
        ValueError: if ``use_rpc`` is true but ``worker_map`` is ``None``.
    """
    if not use_rpc:
        return SendRecvTransport(worker_map, input_device)
    if worker_map is None:
        raise ValueError("'RpcTransport' requires 'worker_map' to be set")
    return RpcTransport(worker_map, input_device)


class RpcTransport(Transport):
    """Transport that ships whole messages through ``torch.distributed.rpc``.

    The sender invokes ``rpc_push_queue`` on the destination worker; the receiver
    reads messages from the module-level ``MessageQueues``.
    """

    def send_message(self, message: PipeMessage, sync: bool = False, skip_header: bool = False) -> None:
        """Send ``message`` (tensors included) to the worker for ``message.dest``.

        The message's tensors are replaced in place by their ``.cpu()`` versions
        before the call. ``sync`` selects ``rpc_sync`` over ``rpc_async``;
        ``skip_header`` is not used by this transport.
        """
        cpu_tensors = tuple(t.cpu() for t in message.tensors)
        message.tensors = cpu_tensors
        assert self.worker_map
        destination = self.worker_map[message.dest]
        rpc_call = torch.distributed.rpc.rpc_sync if sync else torch.distributed.rpc.rpc_async
        rpc_call(destination, rpc_push_queue, args=(message,))

    def recv_message_header(self, queue_name: int, nowait: bool = False) -> PipeMessage:
        """Take the next message from queue ``queue_name``.

        With ``nowait`` the queue's ``get_nowait`` is used instead of the blocking
        ``get``. The message's tensors are moved to ``self.input_device`` before
        it is returned.
        """
        pending = MessageQueues[queue_name]
        received = pending.get_nowait() if nowait else pending.get()
        received.tensors = to_input_device(received.tensors, self.input_device)
        return received

    def recv_message_tensors(self, message: PipeMessage) -> PipeMessage:
        """Return ``message`` with its tensors placed on ``self.input_device``."""
        # Nothing to receive here: the tensors arrived inside the message itself.
        message.tensors = to_input_device(message.tensors, self.input_device)
        return message

    def get_out_of_order(self, queue_name: int, index: int) -> Tensors:
        """Return the tensors of the message whose ``args`` equals ``index``.

        Messages read before the matching one are held back and re-queued on the
        same queue once the match has been found.
        """
        pending = MessageQueues[queue_name]
        deferred: List[PipeMessage] = []
        while True:
            candidate = self.recv_message(queue_name)
            if candidate.args == index:
                # Hand the skipped messages back before returning the match.
                for earlier in deferred:
                    pending.put(earlier)
                return candidate.tensors
            deferred.append(candidate)


class SendRecvTransport(Transport):
    """Transport built on point-to-point ``torch.distributed.send`` / ``recv``.

    A message travels as a serialized header of ``MESSAGE_TENSOR_SIZE`` elements
    (tagged with ``message.queue_name``) followed by one send per tensor (tagged
    ``message.tag + index``), all within the pipeline-parallel process group.
    """

    def send_message(self, message: PipeMessage, sync: bool = False, skip_header: bool = False) -> None:
        """Send ``message`` to rank ``message.dest``.

        Args:
            message: message to send. It is modified in place: ``tensors`` is
                replaced by an empty tuple so the serialized header carries only
                metadata and, unless ``skip_header`` is set, ``tensor_shapes`` and
                ``tensor_dtypes`` are filled in from the tensors.
            sync: part of the shared ``Transport`` signature; not used here.
            skip_header: when true, only the tensors are sent and no header.
        """
        tensors = message.tensors
        message.tensors = tuple()
        torch.cuda.current_stream().synchronize()
        if not skip_header:
            message.tensor_shapes = [t.size() for t in tensors]
            message.tensor_dtypes = [t.dtype for t in tensors]
            torch.distributed.send(
                pyobject_to_tensor(message, MESSAGE_TENSOR_SIZE).cuda(),
                message.dest,
                tag=message.queue_name,
                group=get_pipeline_parallel_group(),
            )
        for index, t in enumerate(tensors):
            # CPU tensors are moved to CUDA before being sent.
            if t.device.type == "cpu":
                t = t.cuda()
            torch.distributed.send(
                t.contiguous(), message.dest, tag=message.tag + index, group=get_pipeline_parallel_group()
            )

    def recv_message_header(self, queue_name: int, nowait: bool = False) -> PipeMessage:
        """Receive and deserialize the next header tagged ``queue_name`` from any rank.

        Raises:
            queue.Empty: always, when ``nowait`` is true (non-blocking receive is
                not implemented for this transport).
        """
        # FIXME(handle nowait)
        if nowait:
            raise QueueEmpty
        tensor = torch.empty(MESSAGE_TENSOR_SIZE, dtype=torch.uint8, device=self.input_device)
        torch.cuda.current_stream().synchronize()
        torch.distributed.recv(tensor, src=None, tag=queue_name, group=get_pipeline_parallel_group())
        torch.cuda.current_stream().synchronize()
        return tensor_to_pyobject(tensor)

    def recv_message_tensors(self, message: PipeMessage) -> PipeMessage:
        """Receive the tensors described by ``message`` from rank ``message.src``.

        One buffer is allocated on ``self.input_device`` for each
        ``(shape, dtype)`` pair in ``message.tensor_shapes`` /
        ``message.tensor_dtypes`` and filled by a ``recv`` tagged
        ``message.tag + index``. The result is stored in ``message.tensors`` and
        the same message object is returned.
        """
        torch.cuda.current_stream().synchronize()

        message_tensors = []
        for index, (shape, dtype) in enumerate(zip(message.tensor_shapes, message.tensor_dtypes)):
            # Pass the shape as one size argument. Unpacking it (``*shape``) would
            # leave torch.empty without any size argument for a 0-dim tensor,
            # whose shape is empty.
            t = torch.empty(shape, dtype=dtype, device=self.input_device)
            torch.distributed.recv(t, message.src, tag=message.tag + index, group=get_pipeline_parallel_group())
            message_tensors.append(t)

        message.tensors = tuple(message_tensors)

        torch.cuda.current_stream().synchronize()
        return message

    def get_out_of_order(self, queue_name: int, index: int) -> Tensors:
        """Receive the next message on ``queue_name`` and return its tensors.

        This transport does not reorder messages: it asserts that the received
        message's ``args`` equals ``index``.
        """
        message = self.recv_message(queue_name)
        assert message.args == index
        return message.tensors
