import torch
import torch.nn as nn
import torch.nn.functional as F


class QNetwork(nn.Module):
    """Feed-forward Q network: two ReLU hidden layers and a linear output head.

    Args:
        input_shape: shape of the input; only the last entry (feature size) is used.
        output_shape: shape of the output; only the first entry (number of
            outputs) is used.
        n_features: width of both hidden layers.
        **kwargs: accepted for interface compatibility and ignored.
    """

    def __init__(self, input_shape, output_shape, n_features, **kwargs):
        super().__init__()

        in_dim = input_shape[-1]
        out_dim = output_shape[0]

        # Layers are created in this order (h1, h2, h3) so parameter
        # registration and default-init RNG consumption follow the same sequence.
        self._h1 = nn.Linear(in_dim, n_features)
        self._h2 = nn.Linear(n_features, n_features)
        self._h3 = nn.Linear(n_features, out_dim)

        # Xavier init scaled for the nonlinearity that follows each layer:
        # ReLU after the hidden layers, nothing (linear) after the head.
        relu_gain = nn.init.calculate_gain("relu")
        for hidden_layer in (self._h1, self._h2):
            nn.init.xavier_uniform_(hidden_layer.weight, gain=relu_gain)
        nn.init.xavier_uniform_(self._h3.weight, gain=nn.init.calculate_gain("linear"))

    def forward(self, state):
        """Return the Q values computed from ``state`` (last dim = input feature size)."""
        activations = state
        for hidden_layer in (self._h1, self._h2):
            activations = F.relu(hidden_layer(activations))
        return self._h3(activations)


class RecurrentQNetwork(nn.Module):
    """Recurrent Q network: ReLU input projection, a single GRU layer, linear Q head.

    Args:
        input_shape: shape of the input; only the last entry (feature size) is used.
        output_shape: shape of the output; only the first entry (number of
            outputs) is used.
        n_features: width of the input projection and of the GRU hidden state.
        **kwargs: accepted for interface compatibility and ignored.
    """

    def __init__(self, input_shape, output_shape, n_features, **kwargs):
        super().__init__()

        in_dim, out_dim = input_shape[-1], output_shape[0]

        # Creation order (h1, gru, h2) is kept so parameter registration and
        # state_dict keys (including the public ``gru`` attribute) are unchanged.
        self._h1 = nn.Linear(in_dim, n_features)
        self.gru = nn.GRU(n_features, n_features)
        self._h2 = nn.Linear(n_features, out_dim)

        # Only the linear layers get an explicit Xavier init; the GRU keeps
        # PyTorch's default initialization.
        for layer, nonlinearity in ((self._h1, "relu"), (self._h2, "linear")):
            nn.init.xavier_uniform_(layer.weight, gain=nn.init.calculate_gain(nonlinearity))

    def forward(self, state, hidden=None, output_hidden=False, output_all=False):
        """Compute Q values from a (possibly multi-step) input sequence.

        Args:
            state: tensor of shape ((time), batch_size, input features); the GRU
                is built with the default ``batch_first=False``, so time leads.
            hidden: optional initial GRU state of shape (1, batch_size, n_features);
                ``None`` lets the GRU start from zeros.
            output_hidden: if True, also return the hidden tensor the Q values
                were computed from.
            output_all: if True, compute Q values for every timestep from the
                GRU's per-step outputs; otherwise only from the final hidden state.

        Returns:
            ``q``, or ``(q, hidden_tensor)`` when ``output_hidden`` is True. With
            ``output_all`` the hidden tensor is the per-step GRU output
            (time, batch, n_features); otherwise it is the final hidden state
            (1, batch, n_features).
        """
        projected = F.relu(self._h1(state))
        per_step_out, final_hidden = self.gru(projected, hidden)

        # For a single-layer GRU the per-step output is the hidden state at each step.
        source = per_step_out if output_all else final_hidden
        q_values = self._h2(source)

        if not output_hidden:
            return q_values
        return q_values, source


class CEMNEtwork(nn.Module):
    """Q network over a concatenated (state, action) input.

    Two ReLU hidden layers followed by a linear output head.

    Args:
        input_shape: only the last entry is used. It must equal the combined last
            dimension of ``state`` and ``action``, because ``forward``
            concatenates them before the first layer.
        output_shape: only the first entry (number of outputs) is used.
        n_features: hidden layer widths. Pass a sequence whose first two entries
            are the widths of the first and second hidden layer, or a single
            integer to use the same width for both.
        **kwargs: accepted for interface compatibility and ignored.
    """

    def __init__(self, input_shape, output_shape, n_features, **kwargs):
        import numbers

        super().__init__()

        n_input = input_shape[-1]
        n_output = output_shape[0]

        # Accept a scalar width (as the sibling Q networks do) in addition to
        # the original two-entry sequence form.
        if isinstance(n_features, numbers.Integral):
            n_features_h1 = n_features_h2 = n_features
        else:
            n_features_h1, n_features_h2 = n_features[0], n_features[1]

        self._h1 = nn.Linear(n_input, n_features_h1)
        self._h2 = nn.Linear(n_features_h1, n_features_h2)
        self._h3 = nn.Linear(n_features_h2, n_output)

        nn.init.xavier_uniform_(self._h1.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self._h2.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self._h3.weight, gain=nn.init.calculate_gain("linear"))

    def forward(self, state, action):
        """Return Q values for ``state`` and ``action`` concatenated on the last dim."""
        # torch.cat accepts non-contiguous inputs, so no .contiguous() copy is needed.
        state_action = torch.cat([state, action], dim=-1)
        features1 = F.relu(self._h1(state_action))
        features2 = F.relu(self._h2(features1))
        q = self._h3(features2)
        return q


# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
CEMNetwork = CEMNEtwork
