import torch
from .rnn_vanilla import RNN
from .lif_gate import StackedLIFGate, StackedSigmoidGate
from torchkit import pytorch_utils as ptu

class LIFGate(RNN):
    """RNN variant configured with ``StackedLIFGate`` as its ``rnn_class``.

    Only the class-level configuration and the zero-state factory are defined
    here; all other behavior is inherited from ``RNN``.
    """

    # Registry name and recurrent-stack class for this variant.
    # NOTE(review): presumably read by the RNN base class — confirm in rnn_vanilla.
    name = "lifgate"
    rnn_class = StackedLIFGate

    def get_zero_internal_state(self, batch_size=1):
        """Build an all-zero internal state.

        Args:
            batch_size: number of sequences in the batch (defaults to 1).

        Returns:
            A zero tensor of shape ``(num_layers, batch_size, 4 * hidden_size)``
            allocated on ``ptu.device``.
        """
        # NOTE(review): the 4x width presumably packs several per-unit state
        # tensors consumed by StackedLIFGate — confirm against lif_gate.
        state_shape = (self.num_layers, batch_size, 4 * self.hidden_size)
        return torch.zeros(state_shape, device=ptu.device)
        

class SigmoidGate(RNN):
    """RNN variant configured with ``StackedSigmoidGate`` as its ``rnn_class``.

    Only the class-level configuration and the zero-state factory are defined
    here; all other behavior is inherited from ``RNN``.
    """

    # Registry name and recurrent-stack class for this variant.
    # NOTE(review): presumably read by the RNN base class — confirm in rnn_vanilla.
    name = "sigmoidgate"
    rnn_class = StackedSigmoidGate

    def get_zero_internal_state(self, batch_size=1):
        """Build an all-zero internal state.

        Args:
            batch_size: number of sequences in the batch (defaults to 1).

        Returns:
            A zero tensor of shape ``(num_layers, batch_size, 4 * hidden_size)``
            allocated on ``ptu.device``.
        """
        # NOTE(review): the 4x width presumably packs several per-unit state
        # tensors consumed by StackedSigmoidGate — confirm against lif_gate.
        state_shape = (self.num_layers, batch_size, 4 * self.hidden_size)
        return torch.zeros(state_shape, device=ptu.device)