from __future__ import division
import torch
import torch.nn as nn
from utils import norm_col_init, weights_init


class AClinear(torch.nn.Module):
    """Actor-critic model made of two independent, bias-free linear heads.

    The critic head maps a ``state_dim``-sized input to a single output and
    the actor head maps a ``state_space``-sized input to ``action_space``
    outputs. The two heads read different inputs and share no parameters.

    Args:
        state_space: Input width of the actor head.
        action_space: Output width of the actor head.
        state_dim: Input width of the critic head.
    """

    def __init__(self, state_space, action_space, state_dim):
        super(AClinear, self).__init__()

        # The critic is registered before the actor; this fixes the order of
        # ``parameters()`` / ``state_dict()`` entries.
        self.critic_linear = nn.Linear(
            in_features=state_dim, out_features=1, bias=False)
        self.actor_linear = nn.Linear(
            in_features=state_space, out_features=action_space, bias=False)

        # Run the project-wide initialiser over every submodule first ...
        self.apply(weights_init)

        # ... then overwrite each head's weights with norm_col_init, actor
        # first. The second argument is presumably a target scale/std
        # (0.01 for the actor, 1.0 for the critic) -- confirm in utils.
        for head, std in ((self.actor_linear, 0.01),
                          (self.critic_linear, 1.0)):
            head.weight.data = norm_col_init(head.weight.data, std)

        # Leave the module in training mode after construction.
        self.train()

    def forward(self, inputs):
        """Evaluate both heads on their respective inputs.

        Args:
            inputs: A two-element iterable ``(state_feature, state_onehot)``.
                The first element is fed to the critic head and the second
                to the actor head. (Judging by its name the second element
                is a one-hot state encoding -- verify against the caller.)

        Returns:
            A tuple ``(critic_output, actor_output)``.
        """
        critic_input, actor_input = inputs
        critic_out = self.critic_linear(critic_input)
        actor_out = self.actor_linear(actor_input)
        return critic_out, actor_out
