import torch

class logi_gate(torch.nn.Module):
    """Logistic gate: an affine ``torch.nn.Linear`` map followed by a sigmoid.

    Args:
        input_dim: ``in_features`` of the underlying linear layer, i.e. the
            size of the last dimension of the input.
        output_dim: ``out_features`` of the underlying linear layer, i.e. the
            size of the last dimension of the output.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # The attribute must stay named ``linear``: it determines the
        # state_dict keys ("linear.weight", "linear.bias").
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return ``sigmoid(self.linear(x))``, applied element-wise."""
        logits = self.linear(x)
        return logits.sigmoid()