import torch
import torch.nn as nn
from core.networks.mlp import MLP
from core.utils.sac_utils import initialize_hidden_layer, initialize_last_layer


# TODO: check that this layer initialization does not harm TD3
class ContinuousQNetwork(nn.Module):
    """Pair of independent Q-value MLPs evaluated on (observation, action) inputs.

    Two separate ``MLP`` instances (``q1_mlp`` and ``q2_mlp``) are built with
    the same constructor arguments. Each one has its layers passed through
    ``initialize_hidden_layer`` (all but the last layer) and
    ``initialize_last_layer`` (the last layer).

    Args:
        observation_dim: Size of the observation vector.
        action_dim: Size of the action vector.
        layers_dim: Forwarded unchanged to ``MLP`` as its third argument
            (presumably the hidden layer sizes -- confirm against ``MLP``).
    """

    def __init__(self, observation_dim, action_dim, layers_dim):
        super(ContinuousQNetwork, self).__init__()

        # Both critics consume the observation and action concatenated along
        # dim 1, hence the summed input size. The second MLP argument is 1
        # (presumably the output size, i.e. a scalar Q-value -- confirm).
        input_dim = observation_dim + action_dim

        # Q1 architecture
        self.q1_mlp = MLP(input_dim, 1, layers_dim)
        self._initialize_mlp(self.q1_mlp)

        # Q2 architecture
        # The initializers must be applied to q2_mlp's own layers; previously
        # q1_mlp's layers were initialized a second time here and q2_mlp was
        # left with whatever initialization ``MLP`` performs by default.
        self.q2_mlp = MLP(input_dim, 1, layers_dim)
        self._initialize_mlp(self.q2_mlp)

    @staticmethod
    def _initialize_mlp(mlp):
        """Apply the hidden/last layer initializers to every layer of ``mlp``."""
        layers = mlp.layers
        for layer in layers[:-1]:
            initialize_hidden_layer(layer)
        initialize_last_layer(layers[-1])

    def forward(self, x, u):
        """Evaluate both critics on the same input.

        Args:
            x: Observation tensor; concatenated with ``u`` along dim 1.
            u: Action tensor; must match ``x`` on every dimension except dim 1.

        Returns:
            Tuple ``(q1, q2)`` holding the outputs of ``q1_mlp`` and ``q2_mlp``.
        """
        xu = torch.cat([x, u], 1)
        x1 = self.q1_mlp(xu)
        x2 = self.q2_mlp(xu)
        return x1, x2

    def Q1(self, x, u):
        """Evaluate only the first critic; same inputs as ``forward``.

        Returns:
            The output of ``q1_mlp`` on ``x`` and ``u`` concatenated along dim 1.
        """
        xu = torch.cat([x, u], 1)
        return self.q1_mlp(xu)