from typing import Tuple, Optional

import torch
from torch import nn
from modules.dru.dru import DRUnit
from util import epsilon_greedy
from modules.pseudo_step import PseudoStep
from constants import *


class MessageEncoder(nn.Module):
    """
    Generates messages for all agents.

    Every row of the input is encoded by a shared layer followed by one head per configured message size.
    The produced message is zero-padded to the largest message size and, when ``mask_len > 0``, followed by a
    mask that marks which message size was used (a leading size-0 entry gets no mask slot).
    """

    def __init__(self, n_agents: int, features_in: int, msg_sizes: Tuple, message_mode: str):
        """
        Initializes the encoder.

        :param n_agents: Number of agents
        :param features_in: Number of input features
        :param msg_sizes: The size of each generated message (a size of 0 denotes an empty message)
        :param message_mode: How messages are produced; compared against the mode constants DISCRETE, CONTINUOUS,
                             DRU, PSEUDOGRADIENT and ZEROS (ZEROS creates no heads, i.e. only empty messages)
        """
        super().__init__()
        self.n_agents = n_agents
        self.features_in = features_in
        self.msg_sizes = msg_sizes
        self.message_mode = message_mode

        def get_output_size(size):
            # DISCRETE heads output one Q-value per possible binary message of length `size`
            if self.message_mode == DISCRETE:
                return 2 ** size
            else:
                return size

        shared_encoder_features_out = self.features_in
        self.msg_encoder_shared = nn.Sequential(
            nn.Linear(self.features_in, shared_encoder_features_out),
            nn.Tanh()
        )

        # ModuleList is required to register the modules (e.g. for .to(device)); None marks "no head" (empty message)
        self.msg_encoder_heads = nn.ModuleList(
            [
                None if (size == 0 or message_mode == ZEROS) else nn.Sequential(
                    nn.Linear(shared_encoder_features_out, get_output_size(size)),
                    nn.Tanh()
                ) for size in self.msg_sizes
            ]
        )

        if self.message_mode == DRU:
            self.dru = DRUnit(sigma=2)

        if self.message_mode == PSEUDOGRADIENT:
            self.pseudo_step = PseudoStep.apply

        self.mask_len = self.get_mask_len(self.msg_sizes)
        self.out_len = self.get_msg_container_len(self.msg_sizes)

    @staticmethod
    def get_mask_len(msg_sizes):
        """
        Get the length of the size-selection mask appended to each message.

        :param msg_sizes: The message sizes (may be None or empty)
        :return: One slot per message size, minus one if the first size is 0 (that size gets no slot)
        """
        if msg_sizes is None or len(msg_sizes) == 0:
            return 0

        if msg_sizes[0] == 0:
            return len(msg_sizes) - 1
        else:
            return len(msg_sizes)

    @staticmethod
    def get_msg_container_len(msg_sizes):
        """
        Get the length of the message container consisting of message and mask.

        :param msg_sizes: The message sizes (may be None or empty)
        :return: The length of the message containers (0 if there are no message sizes)
        """
        # empty sizes must be handled explicitly: max() of an empty sequence raises ValueError
        if msg_sizes is None or len(msg_sizes) == 0:
            return 0

        return max(msg_sizes) + MessageEncoder.get_mask_len(msg_sizes)

    def forward(self, x: torch.Tensor, msg_size_ind=None, discrete_msg_epsilon=0):
        """
        Encode one message container per row of ``x``.

        :param x: Input features; rows index the senders (assumed shape (n, features_in) — TODO confirm)
        :param msg_size_ind: Per-row index into ``msg_sizes`` selecting the message size (presumably shape (n,)).
                             Required when more than one message size is configured; ignored otherwise.
        :param discrete_msg_epsilon: ``epsilon`` passed to ``epsilon_greedy`` in DISCRETE mode
        :return: Tuple ``(msg_list, discrete_msg_selected_q)``: the message containers with shape (n, out_len) and,
                 in DISCRETE mode, the Q-value of the selected message per row with shape (n, 1), otherwise None
        :raises ValueError: If more than one message size is configured and ``msg_size_ind`` is None
        """
        if len(self.msg_sizes) > 1 and msg_size_ind is None:
            raise ValueError("msg_size_ind is required when more than one message size is configured")

        x = self.msg_encoder_shared(x)
        n_rows = x.shape[0]
        max_size = max(self.msg_sizes, default=0)
        single_size = len(self.msg_sizes) == 1

        # shape (x, y): messages from agents in dim x with content in dim y
        msg_list = torch.zeros((n_rows, self.out_len), dtype=torch.float, device=x.device)
        if self.message_mode == DISCRETE:
            discrete_msg_selected_q = torch.zeros((n_rows, 1), dtype=torch.float, device=x.device)
        else:
            discrete_msg_selected_q = None

        for i, size in enumerate(self.msg_sizes):
            if single_size:
                agents_with_this_index = torch.ones((n_rows, 1), dtype=torch.bool, device=x.device)
            else:
                agents_with_this_index = (msg_size_ind == i).unsqueeze(1)

            head = self.msg_encoder_heads[i]
            # skip head for empty message
            if head is not None:
                # encode hidden using head i (output of size (x, msg_size[i]))
                if self.message_mode == DISCRETE:
                    msg_q_values: torch.Tensor = head(x)
                    indices = epsilon_greedy(msg_q_values, epsilon=discrete_msg_epsilon, device=x.device)
                    selected_q = msg_q_values.gather(dim=-1, index=indices.unsqueeze(-1))
                    # save selected Q values (will be returned) only for rows that use this message size
                    discrete_msg_selected_q[:] = ~agents_with_this_index * discrete_msg_selected_q[:] \
                                                    + agents_with_this_index * selected_q
                    # empty is ok because we assign everything
                    msg = torch.empty((n_rows, size), dtype=torch.float, device=x.device)
                    # convert decimal (indices) to binary (message), least significant bit first
                    for a in range(size):
                        remainder = torch.remainder(indices, 2)
                        msg[:, a] = remainder
                        indices = indices.div(2, rounding_mode='floor')
                else:
                    assert self.message_mode == CONTINUOUS or self.message_mode == DRU or \
                           self.message_mode == PSEUDOGRADIENT
                    msg = head(x)

                    if self.message_mode == DRU:
                        msg = self.dru.forward(msg)

                    if self.message_mode == PSEUDOGRADIENT:
                        msg = self.pseudo_step(msg)

                # padding to have the same length all the time
                if not single_size and size < max_size:
                    padding = torch.zeros((n_rows, max_size - msg.shape[1]), device=x.device)
                    msg = torch.cat([msg, padding], dim=1)
            else:
                # create an empty message
                msg = torch.zeros((n_rows, max_size), device=x.device)

            if self.mask_len > 0:
                # masking to indicate which msg size was selected and have meaningful info
                mask = torch.zeros((n_rows, self.mask_len), dtype=torch.float, device=x.device)
                if not (i == 0 and self.msg_sizes[0] == 0):
                    # a leading size-0 entry has no mask slot, so later slots shift down by one
                    mask[:, i if self.msg_sizes[0] != 0 else i - 1] = 1
                msg = torch.cat([msg, mask], dim=1)

            # the gradients are 0 if message size not selected
            msg_list[:] = ~agents_with_this_index * msg_list[:] + agents_with_this_index * msg

        return msg_list, discrete_msg_selected_q


class FlattenMessageDecoder(nn.Module):
    """
    Flatten all input messages.
    """

    def __init__(self, n_agents: int, msg_len: int):
        """
        Initializes the decoder.

        :param n_agents: Number of agents
        :param msg_len: The length of the messages -> expected input shape (batch_size, n_agents, msg_len)
        """
        super().__init__()
        self.n_agents = n_agents
        self.msg_len = msg_len

    def forward(self, msg_in: torch.Tensor):
        """
        Concatenate the messages of all agents into one feature vector per batch entry.

        :param msg_in: Incoming messages of shape (batch_size, n_agents, msg_len)
        :return: Tensor of shape (batch_size, n_agents * msg_len)
        """
        # check the rank first so a wrongly shaped input fails with the message instead of an IndexError
        assert msg_in.dim() == 3 and msg_in.shape[1] == self.n_agents and msg_in.shape[2] == self.msg_len, \
            "Invalid message shape!"

        # just flatten the messages of all agents; reshape (unlike view) also accepts non-contiguous
        # inputs such as the result of a transpose/permute, and still returns a view when possible
        return msg_in.reshape(msg_in.shape[0], msg_in.shape[1] * msg_in.shape[2])

    def get_features_out(self) -> int:
        """
        Number of output features after decoding the messages.

        :return: Number of output features.
        """
        return self.n_agents * self.msg_len


class AggregateMessageDecoder(nn.Module):
    """
    Aggregate all input messages with embeddings.
    """

    def __init__(self, n_agents: int, msg_len: int, embedding_len: Optional[int], mode):
        """
        Initializes the decoder.

        :param n_agents: Number of agents
        :param msg_len: The length of the messages -> expected input shape (batch_size, n_agents, msg_len)
        :param embedding_len: The length of the embeddings. If None (or 0), no embedding is used.
        :param mode: Specifies how messages are aggregated. Available modes are [sum, mean]
        :raises ValueError: If ``mode`` is not one of the available modes
        """
        super().__init__()
        self.n_agents = n_agents
        self.msg_len = msg_len
        self.embedding_len = embedding_len

        if self.embedding_len:
            self.embedding_function = nn.Sequential(
                nn.Linear(msg_len, embedding_len),
                nn.ReLU()
            )

        if mode not in ('sum', 'mean'):
            raise ValueError(f"Unknown mode {mode}")
        self.mode = mode

    def forward(self, msg_in: torch.Tensor):
        """
        Aggregate the incoming messages with the configured mode.

        :param msg_in: Incoming messages (batch_dim, agent_dim, msg_dim)
        :return: aggregated messages
        """
        if self.mode == 'sum':
            return self.forward_sum(msg_in)
        return self.forward_mean(msg_in)

    @staticmethod
    def _valid_message_mask(msg_in: torch.Tensor) -> torch.Tensor:
        """Boolean mask over (batch_dim, agent_dim): True where a message has at least one non-zero entry."""
        return (msg_in != 0).any(dim=-1)

    def forward_sum(self, msg_in: torch.Tensor):
        """
        Sum messages (masked embeddings) over agent dimension.

        :param msg_in: Incoming messages (batch_dim, agent_dim, msg_dim)
        :return: aggregated messages
        """
        if self.embedding_len:
            message_mask = self._valid_message_mask(msg_in)
            # The embedding maps an all-zero (nonexistent) message to ReLU(bias), which is generally non-zero,
            # so the mask has to be applied to the embeddings rather than to the raw messages.
            embeddings = self.embedding_function(msg_in)
            return (message_mask.unsqueeze(-1) * embeddings).sum(dim=-2)

        # messages that are not embedded need no mask because nonexistent messages are 0
        return msg_in.sum(dim=-2)

    def forward_mean(self, msg_in: torch.Tensor):
        """
        Mean of non-zero messages (masked embeddings) over agent dimension.

        :param msg_in: Incoming messages (batch_dim, agent_dim, msg_dim)
        :return: aggregated messages
        """
        # mask of valid messages (True for valid, False for invalid)
        message_mask = self._valid_message_mask(msg_in)

        if self.embedding_len:
            # create message embeddings
            embeddings = self.embedding_function(msg_in)

            # masked sum over the agent dimension
            msg_sum = torch.sum(message_mask.unsqueeze(-1) * embeddings, dim=-2)
        else:
            msg_sum = msg_in.sum(dim=-2)

        # divide msg sum by message count
        # min=1.0 to handle the case with 0 incoming messages, the mean is defined as 0 in that case
        clamped_message_count = torch.clamp(message_mask.sum(dim=-1), min=1.0)
        mean = msg_sum / clamped_message_count.unsqueeze(-1)
        return mean

    def get_features_out(self) -> int:
        """
        Number of output features after decoding the messages.

        :return: Number of output features.
        """
        if self.embedding_len:
            return self.embedding_len

        return self.msg_len


def get_message_sizes(message_actions: Optional[torch.Tensor], msg_size: Tuple[int, ...]) -> torch.Tensor:
    """
    Map a tensor of message actions to the corresponding message sizes.

    :param message_actions: A tensor with message actions; action ``idx`` selects ``msg_size[idx]``
    :param msg_size: The message size for each message action index
    :return: A tensor with the same shape and dtype as ``message_actions`` holding the selected message sizes
             (actions that match no index in ``msg_size`` map to 0). If no message actions are given,
             returns the first message size in a single-element tensor of the default float dtype.
    """
    if message_actions is None:
        return torch.ones(1) * msg_size[0]

    message_sizes = torch.zeros_like(message_actions)
    for idx, size in enumerate(msg_size):
        # each action equals at most one idx, so accumulating the masked sizes acts as a table lookup
        message_sizes += (message_actions == idx) * size

    return message_sizes


def empty_message_tensor(n_agents, msg_sizes, batch_size) -> torch.Tensor:
    """
    Build an all-zero message tensor for every agent, usable as the initial message.

    :param n_agents: Number of agents (size of the first dimension)
    :param msg_sizes: Message sizes; they determine the container length via
                      ``MessageEncoder.get_msg_container_len``
    :param batch_size: Batch size (size of the second dimension)

    :returns: A float tensor of zeros with shape (n_agents, batch_size, container_len).
    """
    container_len = MessageEncoder.get_msg_container_len(msg_sizes)
    shape = (n_agents, batch_size, container_len)
    return torch.zeros(shape, dtype=torch.float)
