
import torch  
import torch.nn as nn  
import torch.nn.functional as F  

class PolicyOp(nn.Module):
    """Three-layer MLP policy head whose outputs are always non-positive.

    The two inputs to ``forward`` are concatenated along dim 1 and passed
    through two tanh hidden layers and a final linear layer. A single
    learnable scalar, ``logTemperature``, is added to every output unit
    before the activation ``-relu(.)``, so every returned value is ``<= 0``.

    Args:
        inin: Width of the concatenated input, i.e. the sum of the dim-1
            sizes of ``x`` and ``fomular`` passed to ``forward``.
        out: Number of output features.
        hidden: Width of both hidden layers (default 128).
        init_log_temperature: Initial value of the learnable
            ``logTemperature`` scalar (default 1.54).
    """

    def __init__(self, inin, out, hidden=128, init_log_temperature=1.54):
        super().__init__()
        # Registration order (affine1, affine2, affine3, logTemperature) is
        # kept so state_dict keys and parameters() ordering stay unchanged.
        self.affine1 = nn.Linear(inin, hidden)
        self.affine2 = nn.Linear(hidden, hidden)
        self.affine3 = nn.Linear(hidden, out)
        # One learnable scalar offset, broadcast-added to every output unit.
        self.logTemperature = nn.Parameter(
            init_log_temperature * torch.ones(1), requires_grad=True
        )

    def output_activation(self, x):
        """Return ``-relu(x)``.

        Positive entries are negated and non-positive entries become 0, so
        the result is always ``<= 0``. This is a method rather than a lambda
        attribute so that the module can be pickled, e.g. by
        ``torch.save(model)``.
        """
        return -F.relu(x)

    def forward(self, x, fomular):
        """Compute the non-positive policy output.

        Args:
            x: Tensor concatenated with ``fomular`` along dim 1. Presumably
                (batch, features); TODO confirm against callers.
            fomular: Tensor with the same dim-0 size as ``x``. Its dim-1
                width plus that of ``x`` must equal ``inin``.

        Returns:
            Tensor with last dimension ``out`` and every entry ``<= 0``.
        """
        state = torch.cat([x, fomular], dim=1)
        h = torch.tanh(self.affine1(state))
        h = torch.tanh(self.affine2(h))
        logits = self.affine3(h) + self.logTemperature
        return self.output_activation(logits)
