
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
class PolicyVAR(nn.Module):
    """Three-layer fully connected network whose outputs are never positive.

    The input passes through two 128-unit linear layers, each followed by
    ``tanh``, then a final linear layer of width ``out``. A single learnable
    scalar (``logTemperature``) is added to every output element, and the
    result is mapped through ``-relu(.)``, so every returned value is <= 0.

    Args:
        inin: Number of input features expected by the first linear layer.
        out: Number of output features produced by the last linear layer.
    """

    def __init__(self, inin: int, out: int) -> None:
        super().__init__()
        self.affine1 = nn.Linear(inin, 128)
        self.affine2 = nn.Linear(128, 128)
        self.affine3 = nn.Linear(128, out)
        # Learnable scalar offset, broadcast over all outputs in ``forward``.
        # The attribute name is kept as-is because it is a ``state_dict`` key.
        # NOTE(review): the initial value 1.54 is not explained in this file;
        # confirm its origin before changing it.
        self.logTemperature = nn.Parameter(
            1.54 * torch.ones(1), requires_grad=True
        )

    def output_activation(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``-relu(x)``: zero where ``x <= 0`` and ``-x`` elsewhere.

        Defined as a method rather than a lambda stored on the instance so
        that the module can be pickled (and therefore saved whole).
        """
        return -F.relu(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x`` and return the non-positive outputs.

        Args:
            x: Tensor whose last dimension has ``inin`` features.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``out``; every element is <= 0.
        """
        hidden = torch.tanh(self.affine1(x))
        hidden = torch.tanh(self.affine2(hidden))
        # Shift the raw outputs by the shared learnable scalar before the
        # sign-flipped ReLU is applied.
        shifted = self.affine3(hidden) + self.logTemperature
        return self.output_activation(shifted)
