from typing import Optional

from constants import *
import torch
import torch.nn as nn
import torch.nn.functional

from modules.messages import MessageEncoder, FlattenMessageDecoder, AggregateMessageDecoder
from util import epsilon_greedy


class MNISTFeatureExtractor(nn.Module):
    """
    Convolutional encoder turning single-channel MNIST-style images into flat feature vectors.

    Layout: Conv(1->16, 3x3) -> ReLU -> Conv(16->32, 3x3) -> MaxPool(2) -> Flatten -> Linear -> ReLU.

    :param dummy_input: Example batch of shape (B, 1, H, W); it is only used to infer how many
        features the convolutional stack emits for the given image size.
    :param features_dim: Length of the produced feature vector.
    """

    def __init__(self, dummy_input, features_dim):
        super().__init__()
        self.features_dim = features_dim

        in_shape = dummy_input.shape
        assert len(in_shape) == 4 and in_shape[1] == 1, (
            f"Input shape error. Expected Bx1xHxW, got {dummy_input.shape}"
        )

        conv_layers = [
            nn.Conv2d(1, 16, (3, 3), (1, 1)),
            nn.ReLU(),
            nn.Conv2d(16, 32, (3, 3), (1, 1)),
            nn.MaxPool2d(2),
            nn.Flatten(),
        ]
        self.cnn = nn.Sequential(*conv_layers)

        # Probe the convolutional stack once (without tracking gradients) to size the dense layer.
        with torch.no_grad():
            n_flat = self.cnn(dummy_input).shape[-1]

        self.dense = nn.Sequential(nn.Linear(n_flat, features_dim), nn.ReLU())

    def forward(self, x):
        conv_features = self.cnn(x)
        return self.dense(conv_features)


class TrafficJunctionFeatureExtractor(nn.Module):
    """
    Maps Traffic Junction input vectors to feature vectors with a single Linear + ReLU layer.

    The linear layer acts on the last dimension, so any leading dimensions of the input are preserved.

    :param dummy_input: Example input; must have exactly 3 dimensions, and its last dimension is
        used as the input length of the linear layer.
    :param features_dim: Length of the produced feature vector.
    """
    def __init__(self, dummy_input, features_dim):
        super(TrafficJunctionFeatureExtractor, self).__init__()
        self.features_dim = features_dim

        # The message states exactly what is checked (3 dims), so a failing shape is not misreported.
        assert len(dummy_input.shape) == 3, (
            f"Input shape error. Expected a 3-dimensional input (last dim = L), got {dummy_input.shape}"
        )

        self.net = nn.Sequential(
            nn.Linear(dummy_input.shape[-1], features_dim),
            nn.ReLU()
        )

    def forward(self, x):
        return self.net(x)


class DuelingQPrediction(nn.Module):
    """
    Predicts Q values using the "dueling" DQN architecture.

    Q(s, a) = V(s) + A(s, a) - 1/|A| * sum_{a'}(A(s, a'))

    The advantage mean is taken over the action dimension (the last dimension) separately for every
    sample, so the Q values of one sample do not depend on the other samples in the batch.

    See "Dueling Network Architectures for Deep Reinforcement Learning" by Wang et al.
    https://arxiv.org/pdf/1511.06581.pdf

    :param in_features: Size of the input feature vector.
    :param num_actions: Number of Q values (actions) to predict.
    """
    def __init__(self, in_features, num_actions):
        super(DuelingQPrediction, self).__init__()
        self.v = nn.Linear(in_features, 1)  # state value V(s)
        self.a = nn.Linear(in_features, num_actions)  # advantages A(s, a)

    def forward(self, x):
        advantage = self.a(x)
        # Mean over actions only (keepdim so it broadcasts per sample), as in the formula above.
        return self.v(x) + (advantage - advantage.mean(dim=-1, keepdim=True))


class AdaComm(nn.Module):
    """
    Agent network with (optionally adaptive-size) communication.

    Pipeline: feature extractor -> concat [optional agent location input, optional decoded incoming
    messages] -> residual core layer -> optional GRU -> Q/value head, optional policy head, optional
    message-size Q head and optional message encoder.
    """

    def __init__(self, env_name: str, dummy_input: torch.Tensor, n_actions: int, n_agents: int, use_location_input: bool,
                 msg_sizes: Optional[tuple] = None, msg_decode_embedding_len: Optional[int] = None,
                 msg_decode_mode: Optional[str] = None, n_features: int = 128, message_mode: Optional[str] = None,
                 msg_size_selection='egreedy', receive_own_message: bool = False,
                 recurrent: bool = False, with_policy_head: bool = False):
        """
        Agent module with adaptive communication.

        :param env_name: Environment name (used to create the correct observation decoder)
        :param dummy_input: Dummy input to calculate the dimension of the feature extractor input
        :param n_actions: The number of actions
        :param n_agents: The number of agents in the environment
        :param use_location_input: Whether to append the location of the agent (n_agents values) to the features
        :param msg_sizes: The sizes of the individual messages. None, an empty tuple or (0,) disables messages;
            more than one size enables adaptive message size selection.
        :param msg_decode_embedding_len: The length of the message embeddings (should be big enough to summarize all
            incoming messages). None means that no embedding is to be used.
        :param msg_decode_mode: The message decode mode (CONCAT, AGGREGATE_MEAN or AGGREGATE_SUM)
        :param n_features: The number of output units for the feature extractors
        :param message_mode: The selected message mode (forwarded to the MessageEncoder)
        :param msg_size_selection: The message size selection mode in [egreedy, softmax, random].
        :param receive_own_message: Whether agents can receive their own message (stored only; not used inside
            this module)
        :param recurrent: Whether to use a GRU after the core layer
        :param with_policy_head: Whether to use a separate policy head with action probabilities for action selection.
            In that case the value head predicts a single value instead of one Q value per action.
        :raises ValueError: for an unsupported env_name or an unknown msg_decode_mode (when messages are used)
        """
        super(AdaComm, self).__init__()

        self.n_actions = n_actions
        self.n_agents = n_agents
        self.use_location_input = use_location_input
        self.msg_size_selection = msg_size_selection
        self.receive_own_message = receive_own_message
        self.recurrent = recurrent
        self.with_policy_head = with_policy_head

        # implicitly set whether we want to use messages (error treatment done later)
        self.use_msg = msg_sizes is not None and len(msg_sizes) >= 1
        if self.use_msg and len(msg_sizes) == 1 and msg_sizes[0] == 0:
            # don't use messages when the msg size is 0
            self.use_msg = False

        self.msg_sizes = msg_sizes
        self.adaptive_msg = (len(msg_sizes) > 1) if self.use_msg else False

        if env_name == POMNIST:
            self.feature_extractor = nn.Sequential(
                MNISTFeatureExtractor(dummy_input, n_features),
                nn.Dropout(0.5)
            )
        elif env_name == TRAFFIC_JUNCTION:
            self.feature_extractor = TrafficJunctionFeatureExtractor(dummy_input, n_features)
        else:
            raise ValueError(f"Feature extractor for environment '{env_name}' is not implemented")

        if self.use_msg:
            input_msg_len = MessageEncoder.get_msg_container_len(msg_sizes)
            if msg_decode_mode == CONCAT:
                self.msg_decoder = FlattenMessageDecoder(n_agents=n_agents, msg_len=input_msg_len)
            elif msg_decode_mode == AGGREGATE_MEAN:
                self.msg_decoder = AggregateMessageDecoder(n_agents=n_agents, msg_len=input_msg_len,
                                                           embedding_len=msg_decode_embedding_len, mode='mean')
            elif msg_decode_mode == AGGREGATE_SUM:
                self.msg_decoder = AggregateMessageDecoder(n_agents=n_agents, msg_len=input_msg_len,
                                                           embedding_len=msg_decode_embedding_len, mode='sum')
            else:
                raise ValueError(f"Unknown message decode mode {msg_decode_mode}")

        core_dim = n_features + (n_agents if use_location_input else 0) \
                   + (self.msg_decoder.get_features_out() if self.use_msg else 0)

        self.core = nn.Sequential(nn.Linear(core_dim, core_dim), nn.ReLU())

        if self.recurrent:
            self.rnn = nn.GRU(input_size=core_dim, hidden_size=core_dim, num_layers=1)
            # nn.LSTM(input_size=core_dim, hidden_size=core_dim, num_layers=1, batch_first=False)

        if self.with_policy_head:
            self.policy_head = nn.Linear(core_dim, n_actions)
            # predict value of current policy instead of individual q-values
            self.n_val_out = 1
        else:
            self.n_val_out = n_actions

        self.q_head = nn.Linear(core_dim, self.n_val_out)  # DuelingQPrediction(core_dim, self.n_val_out)

        self.q_msg_size = nn.Linear(core_dim, len(msg_sizes)) if self.adaptive_msg else None

        if self.use_msg:
            self.msg_encoder = MessageEncoder(n_agents=n_agents, features_in=core_dim,
                                              msg_sizes=msg_sizes, message_mode=message_mode)

    def forward(self, x: torch.Tensor, location: Optional[torch.Tensor], msg_in: Optional[torch.Tensor] = None,
                msg_size_epsilon: float = 0, discrete_msg_epsilon: float = 0, states=None):
        """
        Compute value/Q outputs, policy logits, message-size selection and outgoing messages.

        :param x: Observations. A 5-dimensional input is interpreted as (batch, agent, channel, width, height);
            batch and agent are merged for the computation and the per-sample outputs are reshaped back to
            (batch, agent, ...). Any other rank is passed to the feature extractor as is.
        :param location: Agent location input with n_agents values per sample (only used if use_location_input)
        :param msg_in: Incoming messages. Dims 0 and 1 are swapped before decoding
            ((agent_dim, batch_dim, msg_dim) -> (batch_dim, agent_dim, msg_dim)). Required if messages are used.
        :param msg_size_epsilon: Exploration parameter for the message size selection: epsilon for 'egreedy',
            softmax temperature for 'softmax'; 0 selects greedily in both modes.
        :param discrete_msg_epsilon: Exploration epsilon forwarded to the message encoder
        :param states: Hidden state passed to the GRU (only used if recurrent)
        :return: (q_out, policy_out, q_msg_out, msg_size_ind, msg_out, discrete_msg_selected_q, states).
            policy_out is None without a policy head, q_msg_out/msg_size_ind are None without adaptive message
            sizes, and msg_out/discrete_msg_selected_q are None when messages are not used.
        :raises ValueError: if messages are used but msg_in is None, or msg_size_selection is unknown
        """
        # assumption: batch, agent, channel, width, height
        five_dims = len(x.shape) == 5
        shape = x.shape

        if self.use_msg and msg_in is None:
            raise ValueError("msg_in must be provided when the model uses messages")

        if five_dims:
            # transform input to 4 dims (combine batch x agent)
            n_flat = shape[0] * shape[1]
            x = x.view(n_flat, *shape[2:])

            if location is not None and self.use_location_input:
                location = location.view(n_flat, self.n_agents)

            if self.use_msg:
                # merge the two leading dims, keeping the remaining message dims as they are
                # NOTE(review): assumes msg_in's two leading dims correspond to x's (batch, agent) — confirm
                msg_in = msg_in.reshape(n_flat, *msg_in.shape[2:])

        out = self.feature_extractor(x)

        if self.use_location_input:
            out = torch.cat([out, location], dim=1)

        if self.use_msg:
            # (agent_dim, batch_dim, msg_dim) -> (batch_dim, agent_dim, msg_dim)
            msg_in = torch.transpose(msg_in, 0, 1).contiguous()
            out = torch.cat([out, self.msg_decoder(msg_in)], dim=1)

        # residual core layer
        out = out + self.core(out)

        if self.recurrent:
            # add & remove sequence dimension with only 1 element
            out, states = self.rnn(out.unsqueeze(0), states)
            out = out.squeeze(0)

        q_out = self.q_head(out)
        policy_out = self.policy_head(out) if self.with_policy_head else None

        if self.adaptive_msg:
            q_msg_out = self.q_msg_size(out)
            n_samples = q_msg_out.shape[0]

            if self.msg_size_selection == 'softmax':
                if msg_size_epsilon > 0:
                    # boltzmann softmax exploration with temperature msg_size_epsilon
                    probs = torch.softmax(q_msg_out / msg_size_epsilon, dim=-1)
                    dist = torch.distributions.categorical.Categorical(probs=probs)
                    msg_size_ind = dist.sample().to(device=x.device)
                else:
                    # temperature 0: greedy selection
                    msg_size_ind = q_msg_out.argmax(dim=-1)
            elif self.msg_size_selection == 'egreedy':
                # epsilon-greedy with epsilon = msg_size_epsilon
                msg_size_ind = q_msg_out.argmax(dim=-1)
                if msg_size_epsilon > 0:
                    keep_greedy = torch.rand((n_samples,), device=x.device) >= msg_size_epsilon
                    random_ind = torch.randint(0, len(self.msg_sizes), (n_samples,), device=x.device)
                    msg_size_ind = torch.where(keep_greedy, msg_size_ind, random_ind)
            elif self.msg_size_selection == 'random':
                msg_size_ind = torch.randint(0, len(self.msg_sizes), (n_samples,), device=x.device)
            else:
                raise ValueError(f"Unknown message size selection mode {self.msg_size_selection}.")

        else:
            # no adaptive messages
            q_msg_out, msg_size_ind = None, None

        msg_out = None
        discrete_msg_selected_q = None
        if self.use_msg:
            msg_out, discrete_msg_selected_q = self.msg_encoder(
                out, msg_size_ind, discrete_msg_epsilon
            )

        if five_dims:
            # transform the per-sample outputs back to (batch, agent, ...)
            def unflatten(t):
                return None if t is None else t.reshape(shape[0], shape[1], *t.shape[1:])

            q_out = unflatten(q_out)
            policy_out = unflatten(policy_out)
            q_msg_out = unflatten(q_msg_out)
            msg_size_ind = unflatten(msg_size_ind)
            msg_out = unflatten(msg_out)

        return q_out, policy_out, q_msg_out, msg_size_ind, msg_out, discrete_msg_selected_q, states


def joint_inference(agent_model: AdaComm, observations: torch.Tensor, messages: Optional[torch.Tensor] = None,
                    states=None, epsilon=0.0, msg_size_epsilon=0.0, discrete_msg_epsilon=0.0,
                    receive_own_message=False, device=None):
    """
    Get the actions, msg actions and messages (plus additional stats) for all agents by iterating over the first
    dimension of the given observations and calling the shared agent model once per agent.

    :param agent_model: The model that is used for all agents
    :param observations: The observations for all agents (num_agents, num_envs, *env_obs_dim)
    :param messages: The messages from all agents (sender, envs, message_dim). Required if the model uses messages.
    :param states: States for all agents (Optional) (num_agents, num_envs, *state_dim). If None and the model is
        recurrent, a list with one state entry per agent is created.
    :param epsilon: The exploration factor for epsilon-greedy action selection (ignored with a policy head, where
        actions are sampled from the policy)
    :param msg_size_epsilon: Epsilon for msg size selection
    :param discrete_msg_epsilon: epsilon for discrete messages
    :param receive_own_message: Whether to receive own messages. This argument is used instead of
        ``agent_model.receive_own_message``.
    :param device: The torch device used for the model inputs (should be the same as the one used in the models)
    :return: tuple (actions, msg_actions, messages_out, states, q_dict):
        - actions: long tensor (num_agents, num_envs)
        - msg_actions: long tensor (num_agents, num_envs) with the selected message size indices, or None if the
          model has no adaptive message sizes
        - messages_out: tensor (num_agents, *per-agent message shape), or None if the model uses no messages
        - states: the updated per-agent recurrent states (or the given ``states`` if the model is not recurrent)
        - q_dict: dictionary with 'q_values' and 'log_pi_a' (None without a policy head), plus 'q_msg_sizes' and
          'discrete_msg_selected_q_out' when they are produced
    :raises ValueError: if the model uses messages but ``messages`` is None
    """
    num_agents = observations.shape[0]
    num_envs = observations.shape[1]
    msg_sizes = agent_model.msg_sizes

    if agent_model.use_msg and messages is None:
        raise ValueError("messages must be provided when the agent model uses messages")

    # result buffers, allocated with torch's default device and filled agent by agent
    actions = torch.empty((num_agents, num_envs), dtype=torch.long)
    q_values = torch.empty((num_agents, num_envs, agent_model.n_val_out), dtype=torch.float)

    if agent_model.with_policy_head:
        log_pi_a = torch.empty((num_agents, num_envs, agent_model.n_actions), dtype=torch.float)
    else:
        log_pi_a = None

    if agent_model.adaptive_msg:
        q_msg_sizes = torch.empty((num_agents, num_envs, len(msg_sizes)), dtype=torch.float)
        msg_actions = torch.empty((num_agents, num_envs), dtype=torch.long)
    else:
        q_msg_sizes = msg_actions = None

    messages_out = None
    discrete_msg_selected_q_out = None

    if states is None and agent_model.recurrent:
        states = [None] * num_agents

    for i, obs in enumerate(observations):
        model_input = torch.as_tensor(obs, dtype=torch.float)

        agent_location = None
        if agent_model.use_location_input:
            # one-hot encoding of the agent index, repeated for every environment
            agent_location = torch.zeros(num_envs, num_agents)
            agent_location[:, i] = 1

        if agent_model.use_msg:
            msg_i_in = torch.clone(messages)
            if not receive_own_message:
                # filter own message
                msg_i_in[i, :, :] = 0
        else:
            msg_i_in = None

        if device is not None:
            model_input = model_input.to(device)

            if agent_model.use_location_input:
                agent_location = agent_location.to(device)

            if agent_model.use_msg:
                msg_i_in = msg_i_in.to(device)

        q_values[i], policy_out, q_msg, msg_size_ind, msg_i_out, discrete_msg_i_selected_q, next_states_i = \
            agent_model(model_input, agent_location, msg_in=msg_i_in, msg_size_epsilon=msg_size_epsilon,
                        discrete_msg_epsilon=discrete_msg_epsilon, states=None if states is None else states[i])

        if agent_model.recurrent:
            states[i] = next_states_i

        if agent_model.with_policy_head:
            log_pi_a[i] = torch.log_softmax(policy_out, dim=-1)
            actions[i] = torch.multinomial(torch.softmax(policy_out, dim=-1), 1).squeeze(-1)
        else:
            actions[i] = epsilon_greedy(q_values[i], epsilon=epsilon, device=None)

        if agent_model.adaptive_msg:
            q_msg_sizes[i] = q_msg
            msg_actions[i] = msg_size_ind

        if agent_model.use_msg:
            if messages_out is None:
                # generate messages tensor on the fly (the per-agent message shape is only known now)
                messages_out = torch.empty((num_agents, *msg_i_out.shape))
                if discrete_msg_i_selected_q is not None:
                    # generate discrete max q tensor on the fly
                    discrete_msg_selected_q_out = torch.empty((num_agents, num_envs, 1))

            messages_out[i] = msg_i_out
            if discrete_msg_i_selected_q is not None:
                discrete_msg_selected_q_out[i] = discrete_msg_i_selected_q

    # dictionary of all used q values
    q_dict = dict(
        q_values=q_values,
        log_pi_a=log_pi_a
    )
    if q_msg_sizes is not None:
        q_dict['q_msg_sizes'] = q_msg_sizes
    if discrete_msg_selected_q_out is not None:
        q_dict['discrete_msg_selected_q_out'] = discrete_msg_selected_q_out

    return actions, msg_actions, messages_out, states, q_dict
