from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import norm_col_init, weights_init

class ACMLP(torch.nn.Module):
    """Actor-critic multilayer perceptron with a shared two-layer trunk.

    The trunk is two fully connected layers of width 128, each followed by a
    sigmoid.  Two linear heads read the trunk output:

    * ``critic_linear`` produces a single value per input row.
    * ``actor_linear`` produces ``action_space`` raw outputs per input row
      (no softmax or other normalisation is applied in this module).

    Args:
        action_space: Number of outputs of the actor head.  It is passed
            directly as ``out_features`` of an ``nn.Linear``, so it must be
            an integer.
        state_dim: Size of the last dimension of the input tensor.
    """

    def __init__(self, action_space, state_dim):
        super(ACMLP, self).__init__()

        self.linear1 = nn.Linear(state_dim, 128, bias=True)
        self.linear2 = nn.Linear(128, 128, bias=True)
        self.critic_linear = nn.Linear(128, 1, bias=True)
        self.actor_linear = nn.Linear(128, action_space, bias=True)

        # Project-wide default initialisation for every sub-module first;
        # the per-layer initialisation below then overrides weights and biases.
        self.apply(weights_init)

        # (layer, scale) pairs.  The second value is forwarded unchanged to
        # norm_col_init -- presumably a target column norm / std; confirm in
        # utils.  The order is kept as linear1, linear2, actor, critic so the
        # sequence of norm_col_init calls matches the original.
        init_plan = (
            (self.linear1, 0.01),
            (self.linear2, 0.01),
            (self.actor_linear, 0.01),
            (self.critic_linear, 1.0),
        )
        for layer, scale in init_plan:
            layer.weight.data = norm_col_init(layer.weight.data, scale)
            layer.bias.data.fill_(0)

        # Leave the module in training mode after construction.
        self.train()

    def _hidden(self, inputs):
        """Return the output of the shared trunk (two linear + sigmoid layers)."""
        x = torch.sigmoid(self.linear1(inputs))
        return torch.sigmoid(self.linear2(x))

    def forward(self, inputs):
        """Run both heads on ``inputs``.

        Args:
            inputs: Tensor whose last dimension equals ``state_dim``.

        Returns:
            Tuple ``(value, actor_out)``: the critic head output (last
            dimension 1) and the actor head output (last dimension
            ``action_space``), both computed from the same trunk output.
        """
        x = self._hidden(inputs)
        return self.critic_linear(x), self.actor_linear(x)

    def actor_forward(self, inputs):
        """Run only the actor head on ``inputs``.

        Returns the same tensor as the second element of ``forward(inputs)``
        without evaluating the critic head.
        """
        return self.actor_linear(self._hidden(inputs))
