import collections
import logging
from typing import List

import numpy as np
import torch
import matplotlib.pyplot as plt
from abc import ABCMeta, abstractmethod

from tqdm import tqdm


class Channel(metaclass=ABCMeta):
    """
    Models a potentially lossy broadcast communication channel.

    Subclasses implement :meth:`simulate_transmission`, which decides for every message whether it is dropped.
    Calling a channel instance is a shortcut for :meth:`get_transmitted_messages`.
    """

    @abstractmethod
    def simulate_transmission(self, agent_msg_sizes: torch.Tensor):
        """
        Simulates a transmission of messages with the given sizes.

        :param agent_msg_sizes: A tensor of message sizes with shape (batch_size, nb_agents)
        :return: tensor of shape (batch_size, nb_agents) that is 1 when the respective message has been dropped.
        """
        ...

    def get_transmitted_messages(self, messages: torch.Tensor, message_sizes: torch.Tensor):
        """
        Simulates message transmission using the given message sizes and returns masked messages.

        :param messages: The messages that are sent (sender, batch, message)
        :param message_sizes: The message sizes (sender, batch)
        :return: ``None`` (after logging a warning) if ``messages`` or ``message_sizes`` is ``None``. Otherwise a
                 tuple ``(received_messages, num_dropped, abs_channel_util)`` where ``received_messages`` are the
                 messages with the content of dropped messages zeroed out, ``num_dropped`` is the number of dropped
                 messages per batch entry and ``abs_channel_util`` is the summed size of all messages that were not
                 dropped per batch entry.
        """
        if messages is None or message_sizes is None:
            logging.warning("Cannot transmit any messages without given messages or message sizes.")
            return None

        # transpose message sizes from (sender, batch_size) to (batch_size, sender)
        message_sizes_t = message_sizes.transpose(0, 1)

        # calculate drops and transpose the result back to (sender, batch_size)
        dropped_messages = self.simulate_transmission(message_sizes_t).transpose(0, 1)
        # messages of size 0 are never dropped
        # (multiplied out of place so that the tensor returned by the subclass is never modified)
        dropped_messages = dropped_messages * (message_sizes != 0)

        # (sender, batch_size) => reduce sender dimension
        num_dropped = dropped_messages.sum(dim=0)
        abs_channel_util = ((1 - dropped_messages) * message_sizes).sum(dim=0)

        # apply dropout to messages (content is zeroed out)
        return (
            (1 - dropped_messages.unsqueeze(-1)) * messages,
            num_dropped,
            abs_channel_util
        )

    def __call__(self, *args, **kwargs):
        """
        Forwards all positional and keyword arguments to :meth:`get_transmitted_messages`.
        """
        return self.get_transmitted_messages(*args, **kwargs)

    def __repr__(self):
        return f"{type(self).__name__}({str(self.__dict__)})"


class PerfectChannel(Channel):
    """
    Lossless channel: no message is ever dropped.
    """
    def simulate_transmission(self, agent_msg_sizes: torch.Tensor):
        """
        Reports "not dropped" for every message.

        :param agent_msg_sizes: tensor of message sizes
        :return: an all-zero tensor with the same shape and dtype as ``agent_msg_sizes``
        """
        no_drops = torch.zeros_like(agent_msg_sizes)
        return no_drops

    def __repr__(self):
        # the channel has no parameters, so the class name alone identifies it
        return type(self).__name__


class BlockingChannel(Channel):
    """
    Channel that lets nothing through: every non-empty message is dropped.
    """
    def simulate_transmission(self, agent_msg_sizes: torch.Tensor):
        """
        Marks every message with a non-zero size as dropped.

        :param agent_msg_sizes: tensor of message sizes
        :return: tensor shaped like ``agent_msg_sizes`` that is 1 where the size is non-zero and 0 elsewhere
        """
        has_content = agent_msg_sizes != 0
        return has_content * torch.ones_like(agent_msg_sizes)

    def __repr__(self):
        # the channel has no parameters, so the class name alone identifies it
        return type(self).__name__


class SelectiveChannel(Channel):
    """
    Channel with a fixed per-agent block list: messages of blocked agents are always dropped.
    """
    def __init__(self, blocked_agents: torch.Tensor):
        """
        :param blocked_agents: A one-dimensional tensor with shape (num_agents,). An entry of 1 blocks all messages
               of the respective agent.
        """
        self.blocked_agents = blocked_agents
        nb_dims = len(self.blocked_agents.shape)
        assert nb_dims == 1

    def simulate_transmission(self, agent_msg_sizes: torch.Tensor):
        """
        Repeats the per-agent block mask for every batch entry.

        :param agent_msg_sizes: tensor of message sizes with shape (batch_dim, agent_dim)
        :return: tensor of shape (batch_dim, agent_dim) holding the ``blocked_agents`` value of each agent
        """
        # add a leading batch dimension so that the mask broadcasts over the batch
        per_agent_mask = self.blocked_agents.unsqueeze(0)
        return torch.ones_like(agent_msg_sizes) * per_agent_mask

    def __repr__(self):
        return type(self).__name__


class StochasticChannel(Channel):
    """
    A stochastic channel with a limited number of slots. Messages are inserted randomly and are dropped if multiple
    messages are assigned to the same slot.
    """
    def __init__(self, nb_slots, use_msg_size_spacing=True, verbose=False):
        """

        :param nb_slots: the number of slots (same unit as message size)
        :param use_msg_size_spacing: Controls how messages are inserted into the channel. If True, messages are
               inserted starting from a slot that is a multiple of its size. Otherwise, messages are inserted at any
               slot as long as they completely fit into the channel.
        :param verbose: print debugging information
        """
        self.nb_slots = nb_slots
        self.use_msg_size_spacing = use_msg_size_spacing
        self.verbose = verbose

    def simulate_transmission(self, agent_msg_sizes: torch.Tensor):
        """
        Randomly places every message into the channel and marks overlapping messages as dropped.

        :param agent_msg_sizes: A tensor of message sizes with shape (batch_size, nb_agents)
        :return: float tensor of shape (batch_size, nb_agents) that is 1 when the respective message overlaps with
                 another message or is larger than the channel, and 0 otherwise.
        """
        batch_size = agent_msg_sizes.shape[0]
        nb_agents = agent_msg_sizes.shape[1]
        # create all helper tensors on the device of the input to avoid device mismatches
        device = agent_msg_sizes.device

        # one uniform sample per message that selects its start slot
        random_draw = torch.rand((batch_size, nb_agents), device=device)

        if self.use_msg_size_spacing:
            nb_available_slots = torch.floor(self.nb_slots / agent_msg_sizes)
            chosen_slots = torch.floor(random_draw * nb_available_slots) * agent_msg_sizes
            # messages of size 0 (nan because of division by zero) are "put" into slot 0 (they won't occupy that slot)
            chosen_slots = chosen_slots.nan_to_num(0)
        else:
            # assume that we have nb_slots and put our messages in there randomly but they should still fit in
            max_chosen_slot = torch.clamp_min(self.nb_slots - agent_msg_sizes, 0)  # is inclusive!
            # note that chosen_slots can contain the (invalid) slot nb_slots iff the msg size is 0
            chosen_slots = torch.floor(random_draw * (max_chosen_slot + 1))

        # sort the chosen slots
        sorted_slots, sorted_slot_indices = torch.sort(chosen_slots, dim=-1)
        # the argsort of the sort permutation maps sorted positions back to the original agent order
        sorted_slot_indices_inverted = sorted_slot_indices.argsort()
        # .. and the message sizes accordingly
        sorted_slots_msg_sizes = torch.gather(input=agent_msg_sizes, dim=-1, index=sorted_slot_indices)

        # this tensor will tell us whether there has been a collision for each message
        sorted_collisions = torch.zeros((batch_size, nb_agents), dtype=torch.float, device=device)

        for r in range(0, nb_agents):
            # following messages will collide with this one if they
            # start before this message has been completely transmitted
            collision_vector = sorted_slots[:, r] + sorted_slots_msg_sizes[:, r]

            # make sure to drop messages that exceed the channel's capacity
            msg_size_exceeds_capacity = sorted_slots_msg_sizes[:, r] > self.nb_slots
            sorted_collisions[:, r] = (sorted_collisions[:, r] + msg_size_exceeds_capacity) > 0

            for a in range(r + 1, nb_agents):
                # check if the messages from r overlap the messages from a
                collisions_r_a = (
                    # < because we want to keep messages from r with msg size 0
                        (sorted_slots[:, a] < collision_vector)
                        # messages from a with size 0 do not collide either
                        * (sorted_slots_msg_sizes[:, a] != 0)
                )
                # mark collision for both agents
                sorted_collisions[:, r] = (sorted_collisions[:, r] + collisions_r_a) > 0
                sorted_collisions[:, a] = (sorted_collisions[:, a] + collisions_r_a) > 0

        # undo the sorting so that the result is in the original agent order again
        collisions = torch.gather(input=sorted_collisions, dim=-1, index=sorted_slot_indices_inverted)

        if self.verbose:
            print("Message sizes\n", agent_msg_sizes)
            print("Chosen slots\n", chosen_slots)
            print("Sorted slots\n", sorted_slots)
            print("Sorted msg sizes\n", sorted_slots_msg_sizes)
            print("Sorted collisions\n", sorted_collisions)
            print("Collisions\n", collisions)

        return collisions

    def __repr__(self):
        return f"{type(self).__name__}_{self.nb_slots}{'_spacing' if self.use_msg_size_spacing else ''}"


class PermutationChannel(Channel):
    """
    Inserts messages into a channel of limited size without gaps according to a random permutation. Messages that
    exceed the channel capacity are dropped.
    """
    def __init__(self, nb_slots, priority_mode='equal', verbose=False):
        """

        :param nb_slots: the number of slots (same unit as message size)
        :param priority_mode: message priority mode in ['equal', 'inverted_size']
        :param verbose: Print debug information
        """
        self.nb_slots = nb_slots
        self.priority_mode = priority_mode
        self.verbose = verbose

    def simulate_transmission(self, agent_msg_sizes: torch.Tensor):
        """
        Sends the messages one after another in a random order and drops those that do not fit anymore.

        The sizes of all messages are accumulated in sending order (including messages that are dropped). A message
        is dropped if the accumulated size up to and including this message exceeds ``nb_slots``.

        :param agent_msg_sizes: A tensor of message sizes with shape (batch_size, nb_agents)
        :return: float tensor of shape (batch_size, nb_agents) that is 1 when the respective message has been dropped.
        :raises ValueError: if ``priority_mode`` is neither 'equal' nor 'inverted_size'
        """
        batch_size = agent_msg_sizes.shape[0]
        nb_agents = agent_msg_sizes.shape[1]
        # create the random priorities on the device of the input to avoid device mismatches
        device = agent_msg_sizes.device

        # get random permutation of agents (agents with a lower priority value send first)
        if self.priority_mode == 'equal':
            priority = torch.rand((batch_size, nb_agents), device=device)
        elif self.priority_mode == 'inverted_size':
            # dividing by the size makes larger messages more likely to get a low value
            priority = torch.rand((batch_size, nb_agents), device=device) / agent_msg_sizes
        else:
            raise ValueError(f"Unknown priority mode {self.priority_mode}")
        _, sorted_agents = torch.sort(priority, dim=-1)

        sorted_slots_msg_sizes = torch.gather(input=agent_msg_sizes, dim=-1, index=sorted_agents)
        # the argsort of the sort permutation maps sorted positions back to the original agent order
        sorted_slots_msg_sizes_inverted = sorted_agents.argsort()

        # slots used up to and including each message, in sending order
        # (position 0 is not agent 0 but the one that sends first)
        sorted_used_slots = torch.cumsum(sorted_slots_msg_sizes.to(torch.float), dim=-1)
        # drop messages if they exceed the capacity, never drop 0-messages
        sorted_drops = ((sorted_used_slots > self.nb_slots) * (sorted_slots_msg_sizes != 0)).to(torch.float)

        collisions = torch.gather(input=sorted_drops, dim=-1, index=sorted_slots_msg_sizes_inverted)

        if self.verbose:
            print("Message sizes\n", agent_msg_sizes)
            print("Priority\n", priority)
            print("Message Order\n", sorted_agents)
            print("Sorted msg sizes\n", sorted_slots_msg_sizes)
            print("Sorted drops\n", sorted_drops)
            print("Collisions\n", collisions)

        return collisions

    def __repr__(self):
        return f"{type(self).__name__}_{self.nb_slots}_{self.priority_mode}"


def get_random_message_sizes(batch_size, nb_agents, msg_size):
    """
    Draw one message size per agent and batch entry, uniformly from the valid sizes.

    :param batch_size: Batch size
    :param nb_agents: Number of agents
    :param msg_size: Tuple of valid message sizes
    :return: A tensor of shape (batch_size, nb_agents) whose entries are taken from ``msg_size``
    """
    # pick a random position within msg_size for every agent ...
    picks = torch.randint(len(msg_size), (batch_size, nb_agents))
    # ... and look up the size stored at that position
    size_lookup = torch.tensor(msg_size)
    return size_lookup[picks]


def approximate_drops_and_util(nb_agents, msg_size, channel: Channel, num_tests=1_000_000, drop_out='absolute'):
    """
    Estimate drop statistics and the channel utilisation of a channel by simulating one large batch of
    transmissions with randomly chosen message sizes.

    :param nb_agents: Number of agents
    :param msg_size: Tuple of valid message sizes
    :param channel: Channel
    :param num_tests: Number of tests (= batch size)
    :param drop_out: 'ratio': dropped messages divided by the number of non-empty messages,
                     'absolute': number of dropped messages divided by ``num_tests``,
                     'ratio_per_size': dropped messages divided by sent messages, separately for each entry of
                     ``msg_size`` (returned as numpy array)
    :return: drop result (as selected by ``drop_out``), mean channel utilization per test
    :raises ValueError: if ``drop_out`` is not one of the modes listed above
    """
    sizes = get_random_message_sizes(num_tests, nb_agents, msg_size)
    non_empty = sizes != 0

    dropped = channel.simulate_transmission(sizes)
    # messages of size 0 never count as dropped
    dropped *= non_empty

    if drop_out not in ('ratio', 'absolute', 'ratio_per_size'):
        raise ValueError(f"Unknown drop_out mode {drop_out}")

    if drop_out == 'ratio':
        # clamp the denominator to avoid a division by zero when all messages are empty
        drop_result = dropped.sum() / non_empty.sum().clamp(min=1)
    elif drop_out == 'absolute':
        drop_result = dropped.sum() / num_tests
    else:
        per_size_drops = np.zeros(len(msg_size))
        per_size_total = np.zeros(len(msg_size))
        for idx, size in enumerate(msg_size):
            has_size = sizes == size
            per_size_drops[idx] = (has_size * dropped).sum()
            per_size_total[idx] = has_size.sum()

        # a size that was never drawn gets a ratio of 0 instead of a division by zero
        drop_result = per_size_drops / np.maximum(per_size_total, 1)

    # utilization = summed size of all messages of a test that were not dropped, averaged over all tests
    delivered_sizes = (1 - dropped) * sizes
    util_per_test = delivered_sizes.sum(dim=-1).float()
    return drop_result, util_per_test.mean()


def plot_drop_ratio_over_msg_size(nb_agents, max_msg_size_list, channels: List[Channel]):
    """
    Plot the drop ratio and the channel utilization of each channel over the maximum message size.

    For every entry ``max_size`` of ``max_msg_size_list`` the valid message sizes are (0, ..., max_size - 1).
    The figure is only created here; showing or saving it is left to the caller.

    :param nb_agents: Number of agents
    :param max_msg_size_list: List of (exclusive) maximum message sizes, used as x values of both plots
    :param channels: The channels to compare, each is drawn as one line per plot
    """
    # one result list per channel, stored by position so that a channel that is listed twice
    # still gets exactly one value per x value
    drop_percentages = [[] for _ in channels]
    channel_utils = [[] for _ in channels]
    print("Creating drop ratio & utilization plots..")
    for idx, channel in enumerate(tqdm(channels)):
        for max_size in max_msg_size_list:
            drop_ratio, channel_util = approximate_drops_and_util(
                nb_agents, tuple(range(max_size)), channel, drop_out='ratio'
            )
            # the ratio is a fraction in [0, 1] but the axis is labelled in percent
            drop_percentages[idx].append(100 * float(drop_ratio))
            channel_utils[idx].append(float(channel_util))

    f, (ax1, ax2) = plt.subplots(2, 1)
    f.suptitle(f"Channel stats with {nb_agents} agents and message sizes in (0, ..., max_size - 1)", fontsize=16)

    for channel, percentages in zip(channels, drop_percentages):
        ax1.plot(max_msg_size_list, percentages, label=str(channel))
    ax1.set_xlabel('Max message size')
    ax1.set_ylabel('Drop ratio [%]')
    ax1.grid()
    ax1.legend()

    for channel, utils in zip(channels, channel_utils):
        ax2.plot(max_msg_size_list, utils, label=str(channel))
    ax2.set_xlabel('Max message size')
    ax2.set_ylabel('Channel utilization [slots]')
    ax2.grid()
    ax2.legend()


def main():
    """
    Demo: run one verbose transmission, print per-size drop statistics for several channels and show the
    drop ratio / utilization plots.
    """
    nb_slots = 32
    nb_agents = 4
    batch_size = 2
    msg_size = (0, 1, 2, 4, 16)

    # random message sizes for each agent (dim: batch_size x nb_agents)
    demo_sizes = get_random_message_sizes(batch_size, nb_agents, msg_size)

    # put these messages into a slotted channel once and print what happens
    demo_channel = StochasticChannel(nb_slots=nb_slots, use_msg_size_spacing=True, verbose=True)
    demo_channel.simulate_transmission(demo_sizes)

    # all channels that are compared share the capacity of the demo channel
    capacity = demo_channel.nb_slots
    channels = [
        PerfectChannel(),
        BlockingChannel(),
        StochasticChannel(capacity, use_msg_size_spacing=False),
        StochasticChannel(capacity, use_msg_size_spacing=True),
        PermutationChannel(capacity, priority_mode='equal', verbose=False),
        PermutationChannel(capacity, priority_mode='inverted_size', verbose=False),
    ]

    print(f"\nDrop ratio, channel util for random transmission of msg sizes {msg_size}:")
    for c in channels:
        stats = approximate_drops_and_util(nb_agents, msg_size, c, drop_out='ratio_per_size')
        print(f"> {c}: {stats}")

    max_sizes = list(range(1, nb_slots))
    plot_drop_ratio_over_msg_size(nb_agents, max_sizes, channels)

    plt.show()


# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
