import torch
import torch.nn as nn
import torch.nn.functional as F


class DiscreteActorNetwork(nn.Module):
    """
    Feed-forward actor for discrete action spaces.

    Two ReLU hidden layers of equal width followed by a linear head that emits
    one unnormalised logit per output unit.
    """

    def __init__(self, input_shape, output_shape, n_features, **kwargs):
        """
        Args:
            input_shape: input shape; only its last entry is used as the input size.
            output_shape: output shape; its first entry is the number of logits.
            n_features: width of both hidden layers.
            **kwargs: unused here.
        """
        super().__init__()

        in_dim, out_dim = input_shape[-1], output_shape[0]

        self._h1 = nn.Linear(in_dim, n_features)
        self._h2 = nn.Linear(n_features, n_features)
        self._h3 = nn.Linear(n_features, out_dim)

        # Xavier-uniform weights, with the gain matched to the activation that
        # follows each layer (ReLU for hidden layers, identity for the head).
        layer_gains = (
            (self._h1, "relu"),
            (self._h2, "relu"),
            (self._h3, "linear"),
        )
        for layer, nonlinearity in layer_gains:
            nn.init.xavier_uniform_(
                layer.weight, gain=nn.init.calculate_gain(nonlinearity)
            )

    def forward(self, state):
        """Return the raw logits computed from ``state``."""
        hidden = F.relu(self._h2(F.relu(self._h1(state))))
        return self._h3(hidden)


class DiscreteCriticNetwork(nn.Module):
    """
    Feed-forward critic that scores a (state, action) pair.

    The state and action are cast to float, concatenated on the last dimension,
    and passed through two ReLU hidden layers and a linear head.
    """

    def __init__(self, input_shape, output_shape, n_features, **kwargs):
        """
        Args:
            input_shape: input shape; its last entry must equal the combined
                state + action feature size.
            output_shape: output shape; its first entry is the head's output size.
            n_features: width of both hidden layers.
            **kwargs: unused here.
        """
        super().__init__()

        in_dim, out_dim = input_shape[-1], output_shape[0]

        self._h1 = nn.Linear(in_dim, n_features)
        self._h2 = nn.Linear(n_features, n_features)
        self._h3 = nn.Linear(n_features, out_dim)

        layer_gains = (
            (self._h1, "relu"),
            (self._h2, "relu"),
            (self._h3, "linear"),
        )
        for layer, nonlinearity in layer_gains:
            nn.init.xavier_uniform_(
                layer.weight, gain=nn.init.calculate_gain(nonlinearity)
            )

    def forward(self, state, action):
        """
        Return the critic output for ``state`` and ``action``.

        Every size-1 dimension is squeezed away. With a batch of one and a single
        output unit, the result is therefore a 0-d tensor.
        """
        # Actions may arrive as integer indices; cast so they can be concatenated.
        joint = torch.cat([state.float(), action.float()], dim=-1)
        hidden = F.relu(self._h2(F.relu(self._h1(joint))))
        return self._h3(hidden).squeeze()


class MADDPGCriticNetwork(nn.Module):
    """
    Feed-forward critic that scores a state together with one or more actions.

    The state and the (possibly several) action tensors are concatenated on the
    last dimension and passed through two ReLU hidden layers and a linear head.
    """

    def __init__(self, input_shape, output_shape, n_features, **kwargs):
        """
        Args:
            input_shape: input shape; its last entry must equal the state feature
                size plus the summed feature sizes of all action tensors.
            output_shape: output shape; its first entry is the head's output size.
            n_features: width of both hidden layers.
            **kwargs: unused here.
        """
        super().__init__()

        n_input = input_shape[-1]
        n_output = output_shape[0]

        self._h1 = nn.Linear(n_input, n_features)
        self._h2 = nn.Linear(n_features, n_features)
        self._h3 = nn.Linear(n_features, n_output)

        nn.init.xavier_uniform_(self._h1.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self._h2.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self._h3.weight, gain=nn.init.calculate_gain("linear"))

    def forward(self, state, actions):
        """
        state: tensor of shape ((time), batch_size, n_features)
        actions: a tensor of shape ((time), batch_size, n_actions), or a list/tuple
            of such tensors (e.g. one per agent). They are concatenated on the
            last dimension in the given order.

        Returns the head output with every size-1 dimension squeezed away.
        """
        # torch.cat cannot take a Python list nested inside its input tuple, so
        # merge a sequence of action tensors into a single tensor first.
        if isinstance(actions, (list, tuple)):
            actions = torch.cat(list(actions), dim=-1)

        state_action = torch.cat((state, actions), dim=-1)
        features1 = F.relu(self._h1(state_action))
        features2 = F.relu(self._h2(features1))
        q = self._h3(features2)
        return torch.squeeze(q)


class GRUDiscreteActorNetwork(nn.Module):
    """
    Recurrent actor for discrete action spaces.

    A linear + ReLU encoder feeds a single-layer GRU, and a linear head maps GRU
    features to unnormalised logits.
    """

    def __init__(self, input_shape, output_shape, n_features, **kwargs):
        """
        Args:
            input_shape: input shape; only its last entry is used as the input size.
            output_shape: output shape; its first entry is the number of logits.
            n_features: width of the encoder output and of the GRU hidden state.
            **kwargs: unused here.
        """
        super().__init__()

        in_dim, out_dim = input_shape[-1], output_shape[0]

        self._h1 = nn.Linear(in_dim, n_features)
        self.gru = nn.GRU(n_features, n_features)
        self._h2 = nn.Linear(n_features, out_dim)

        # Only the linear layers get Xavier init; the GRU keeps PyTorch's default.
        for layer, nonlinearity in ((self._h1, "relu"), (self._h2, "linear")):
            nn.init.xavier_uniform_(
                layer.weight, gain=nn.init.calculate_gain(nonlinearity)
            )

    def forward(self, state, hidden=None, output_hidden=False, output_all=False):
        """
        state: tensor of shape ((time), batch_size, n_features)
        hidden: tensor of shape (1, batch_size, n_features), or None for a zero
            initial state
        output_hidden: if True, also return the features the logits were computed from
        output_all: if True, compute logits from the GRU output at every timestep;
            otherwise compute them from the final hidden state only

        NOTE(review): nn.GRU is built without batch_first, so a 3-D ``state`` is
        read as (time, batch, features). PyTorch treats a 2-D ``state`` as an
        unbatched (time, features) sequence, not as (batch, features). Confirm
        that callers always pass a time dimension.
        """
        encoded = F.relu(self._h1(state))
        per_step, final_hidden = self.gru(encoded, hidden)

        # Pick which GRU features feed the head: all timesteps, or the last state.
        features = per_step if output_all else final_hidden
        logits = self._h2(features)

        return (logits, features) if output_hidden else logits


class ContinuousActorNetwork(nn.Module):
    """
    Feed-forward actor for continuous action spaces.

    Two ReLU hidden layers followed by a tanh output layer, so every action
    component lies in [-1, 1].
    """

    def __init__(self, input_shape, output_shape, n_features, **kwargs):
        """
        Args:
            input_shape: input shape; only its last entry is used as the input size.
            output_shape: output shape; its first entry is the number of action components.
            n_features: hidden-layer widths, either as a sequence ``(h1, h2)`` or
                as a single int that is used for both layers.
            **kwargs: unused here.
        """
        super().__init__()

        n_input = input_shape[-1]
        n_output = output_shape[0]

        # A single int means "same width for both hidden layers".
        if isinstance(n_features, int):
            n_features_h1 = n_features_h2 = n_features
        else:
            n_features_h1 = n_features[0]
            n_features_h2 = n_features[1]

        self._h1 = nn.Linear(n_input, n_features_h1)
        self._h2 = nn.Linear(n_features_h1, n_features_h2)
        self._out = nn.Linear(n_features_h2, n_output)

        nn.init.xavier_uniform_(self._h1.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self._h2.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self._out.weight, gain=nn.init.calculate_gain("tanh"))

    def forward(self, state):
        """Return actions in [-1, 1] computed from ``state``."""
        features1 = F.relu(self._h1(state))
        features2 = F.relu(self._h2(features1))
        # torch.tanh is the supported spelling (F.tanh was deprecated).
        return torch.tanh(self._out(features2))


class ContinuousCriticNetwork(nn.Module):
    """
    Feed-forward critic for continuous action spaces.

    The state and action are cast to float, concatenated on the last dimension,
    and passed through two ReLU hidden layers and a linear output layer.
    """

    def __init__(self, input_shape, output_shape, n_features, **kwargs):
        """
        Args:
            input_shape: input shape; its last entry must equal the combined
                state + action feature size.
            output_shape: output shape; its first entry is the output layer's size.
            n_features: hidden-layer widths, either as a sequence ``(h1, h2)`` or
                as a single int that is used for both layers.
            **kwargs: unused here.
        """
        super().__init__()

        n_input = input_shape[-1]
        n_output = output_shape[0]

        # A single int means "same width for both hidden layers".
        if isinstance(n_features, int):
            n_features_h1 = n_features_h2 = n_features
        else:
            n_features_h1 = n_features[0]
            n_features_h2 = n_features[1]

        self._h1 = nn.Linear(n_input, n_features_h1)
        self._h2 = nn.Linear(n_features_h1, n_features_h2)
        self._out = nn.Linear(n_features_h2, n_output)

        nn.init.xavier_uniform_(self._h1.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self._h2.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self._out.weight, gain=nn.init.calculate_gain("linear"))

    def forward(self, state, action):
        """
        Return the critic output for ``state`` and ``action``.

        Every size-1 dimension is squeezed away. With a batch of one and a single
        output unit, the result is therefore a 0-d tensor.
        """
        state_action = torch.cat((state.float(), action.float()), dim=-1)

        features1 = F.relu(self._h1(state_action))
        features2 = F.relu(self._h2(features1))
        q = self._out(features2)

        return torch.squeeze(q)
