import torch
import torch.nn as nn

class Neuron(nn.Module):
    """
    Base class for Spiking Neurons (SNNs).

    Every instance receives a ``layer`` index taken from a counter stored on
    the class; the counter is advanced by one each time an instance is
    constructed. The instance also carries running ``spikes`` / ``total``
    accumulators that are fed through :meth:`updates_spikes` and cleared with
    :meth:`reset_spikes`.

    ``forward`` and ``initial_state`` are empty placeholders here; they are
    presumably meant to be overridden by concrete neuron models.
    """

    # Construction counter shared through the ``Neuron`` class itself; its
    # value at construction time becomes the new instance's ``layer`` index.
    layer: int = 0

    def __init__(self):
        """Initialise the module, zero the counters and claim a layer index."""
        super().__init__()
        # Accumulated by ``updates_spikes``; annotated as float because that
        # method accepts a float spike count.
        self.spikes: float = 0
        self.total: int = 0
        # Initialised to zero and not modified anywhere in this class.
        self.num_neurons: int = 0
        # The instance attribute shadows the class-level counter with the
        # value the counter had when this instance was created.
        self.layer = Neuron.layer
        Neuron.layer += 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Placeholder forward pass.

        This base implementation does nothing and returns ``None``.

        Args:
            x: Input tensor (unused here).
        """

    def initial_state(self, input_tensor: torch.Tensor):
        """Placeholder state initialisation.

        This base implementation does nothing and returns ``None``.

        Args:
            input_tensor: Input tensor (unused here).
        """

    def updates_spikes(self, spikes: float, total: int) -> None:
        """Add ``spikes`` and ``total`` to the running counters.

        Args:
            spikes: Amount added to ``self.spikes``.
            total: Amount added to ``self.total``.
        """
        self.spikes += spikes
        self.total += total

    def reset_spikes(self) -> None:
        """Reset both running counters (``spikes`` and ``total``) to zero."""
        self.spikes = 0
        self.total = 0
