import numpy as np
import torch as th
from torch import nn

class Critic(nn.Module):
    """MLP that maps a pair of state tensors to one scalar score per pair.

    The two inputs are concatenated along their last axis and passed through
    a network with two ReLU hidden layers and a single linear output unit.
    The module also owns an Adam optimizer over all of its parameters,
    exposed as ``self.optimizer``.

    Args:
        dim: Size of the last axis of each state tensor.
        learning_rate: Learning rate for the bundled Adam optimizer.
        hidden_dim: Width of both hidden layers (default 64).
    """

    def __init__(self, dim, learning_rate=1e-6, hidden_dim=64):
        super().__init__()
        self.dim = dim
        input_dim = 2 * dim
        self.seq = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )
        # Created after the layers so it captures every parameter registered above.
        self.optimizer = th.optim.Adam(self.parameters(), lr=learning_rate)

    def forward(self, states_0, states_1):
        """Score each (states_0, states_1) pair.

        Args:
            states_0: Tensor of shape ``(..., dim)``.
            states_1: Tensor of shape ``(..., dim)`` with the same leading
                dimensions as ``states_0``.

        Returns:
            Tensor of shape ``(...)``: one score per pair, with the network's
            trailing size-1 output axis removed.
        """
        # Concatenate on the last axis rather than a fixed axis 1 so any number
        # of leading batch dimensions works; for 2-D inputs this is identical
        # to the previous cat(..., 1) / [:, 0] behaviour.
        inputs = th.cat((states_0, states_1), dim=-1)
        return self.seq(inputs).squeeze(-1)
