import torch
from torch import nn

class LearnableMembrane(nn.Module):
    def __init__(self):
        """
        A generic class for learnable membrane function for SNNs.
        """
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass

    @torch.no_grad()
    def plot(self, x: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Plot the membrane function.
        """
        pass

    def print_parameters(self):
        """
        Print the parameters of the membrane function.
        """
        pass