
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
class PolicyTWOVAR(nn.Module):  
    def __init__(self,inin,out): 
        super(PolicyTWOVAR, self).__init__()  
        self.affine1 = nn.Linear(inin, 128)  
        self.affine2 = nn.Linear(128, 128)  
        self.affine3 = nn.Linear(128, out)  
        self.output_activation = lambda x: -torch.nn.functional.relu(x)
        self.logTemperature = torch.nn.Parameter(1.54 * torch.ones(1), requires_grad=True)
    def forward(self, state, formula, operator):
        state = torch.cat([state, formula,operator], dim=1)
        x = torch.tanh(self.affine1(state))
        x = torch.tanh(self.affine2(x))
        x = self.affine3(x) + self.logTemperature
        x = self.output_activation(x)
        return x
