from typing import Optional

import torch
import torch.nn as nn

class Neuron(nn.Module):
    """
    Base class for Spiking Neurons (SNNs).

    Every construction of a ``Neuron`` (or subclass) increments the
    class-wide counter ``Neuron.layer`` and stores the new value on the
    instance as ``self.layer``. The instance also keeps simple spike
    statistics that are updated through :meth:`updates_spikes` and cleared
    by :meth:`reset_spikes`.

    Attributes:
        layer: Value of ``Neuron.layer`` right after this instance
            incremented it (1 for the first instance if nothing else
            touches the counter).
        spikes: Running sum of the ``spikes`` values passed to
            :meth:`updates_spikes`; ``None`` until the first update.
        spike_rate: Smoothed spike rate maintained by
            :meth:`updates_spike_rate`; ``None`` until the first update.
        count: Initialised to 0; not read or written elsewhere in this class.
    """

    # Class-wide construction counter. It must stay below the docstring:
    # a statement placed before the docstring would turn it into a plain
    # expression and leave ``Neuron.__doc__`` as ``None``.
    layer: int = 0

    def __init__(self):
        super().__init__()
        self.spikes: Optional[float] = None
        self.spike_rate: Optional[float] = None
        # Always bump the counter on ``Neuron`` itself (not ``type(self)``),
        # so all subclasses share a single sequence of layer indices.
        Neuron.layer += 1
        self.count = 0
        self.layer = Neuron.layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the neuron's output for ``x``.

        Intended to be overridden by concrete neuron types. This base
        implementation does nothing and returns ``None``.
        """
        pass

    def initial_state(self, input_tensor: torch.Tensor):
        """Hook for setting up per-input state; a no-op in this base class.

        NOTE(review): presumably subclasses derive their state from
        ``input_tensor`` (e.g. its shape) — confirm against implementations.
        """
        pass

    def updates_spike_rate(self, spikes: float, total: int) -> None:
        """Fold the ratio ``spikes / total`` into :attr:`spike_rate`.

        The first call stores the ratio as-is. Every later call replaces the
        stored value with the mean of the old value and the new ratio. This
        is an exponential moving average with weight 1/2, not the arithmetic
        mean of all ratios seen so far, so recent updates dominate.

        Args:
            spikes: Spike count. Only ``/`` and ``+`` are applied to it, so a
                tensor works as well as a Python number.
            total: Denominator of the rate. A zero ``total`` is not guarded
                against (``ZeroDivisionError`` for Python numbers).
        """
        rate = spikes / total
        if self.spike_rate is None:
            self.spike_rate = rate
        else:
            # Out-of-place update, matching updates_spikes: never mutate an
            # object that may also be referenced elsewhere.
            self.spike_rate = (self.spike_rate + rate) / 2

    def updates_spikes(self, spikes: float, total: int) -> None:
        """Record ``spikes`` spikes out of ``total`` possible ones.

        Updates :attr:`spike_rate` via :meth:`updates_spike_rate`, then adds
        ``spikes`` to the running :attr:`spikes` total.

        Args:
            spikes: Spike count to accumulate (number or tensor).
            total: Denominator forwarded to :meth:`updates_spike_rate`.
        """
        self.updates_spike_rate(spikes, total)
        if self.spikes is None:
            self.spikes = spikes
        else:
            # Out-of-place add: the first call stores the caller's object
            # itself, so an in-place ``+=`` on a tensor would silently
            # modify the tensor the caller passed in on that first call.
            self.spikes = self.spikes + spikes

    def reset_spikes(self) -> None:
        """Clear the accumulated spike count and spike rate back to ``None``."""
        self.spikes = None
        self.spike_rate = None
