# Learnable Neuron Models (LNM) for Spiking Neural Networks

This repository contains the official implementation of the paper **"Learning Neuron Dynamics within Deep Spiking Neural Networks"**, submitted to **ICLR 2026**.

## Abstract

Spiking Neural Networks (SNNs) offer a promising energy-efficient alternative to Artificial Neural Networks (ANNs) by utilizing sparse and asynchronous processing through discrete spike-based computation. However, the performance of deep SNNs remains limited by their reliance on simple neuron models, such as the Leaky Integrate-and-Fire (LIF) model, which cannot capture rich temporal dynamics. While more expressive neuron models exist, they require careful manual tuning of hyperparameters and are difficult to scale effectively. This difficulty is evidenced by the lack of successful implementations of complex neuron models in high-performance deep SNNs.

In this work, we address this limitation by introducing Learnable Neuron Models (LNMs). LNMs are a general, parametric formulation for non-linear integrate-and-fire dynamics that learn neuron dynamics during training. By learning neuron dynamics directly from data, LNMs enhance the performance of deep SNNs. We instantiate LNMs using low-degree polynomial parameterizations, enabling efficient and stable training. We demonstrate state-of-the-art performance on a variety of datasets, including CIFAR-10, CIFAR-100, ImageNet, and CIFAR-10 DVS. LNMs offer a promising path toward more scalable and high-performing spiking architectures.

## Highlights

- **Novel Learnable Neuron Architecture**: Introduction of polynomial-based learnable membrane potential updates
- **Comprehensive Evaluation**: Extensive experiments on CIFAR-10, CIFAR-100, CIFAR-10-DVS, and ImageNet datasets
- **Multiple Model Architectures**: Support for ResNet-19 and VGG architectures with LNM neurons
- **Ablation Studies**: Thorough analysis of polynomial degrees and network configurations

## Requirements

The code requires Python 3.12+ and the following dependencies:

```text
torch==2.6
norse==1.1.0
tensorboard==2.16.2
tqdm==4.67.1
pandas==2.2.1
numpy==1.26.4
matplotlib==3.8.3
scipy==1.12.0
tonic==1.4.3
recordclass==0.21.1
datasets==3.1.0
```

## Installation

Install dependencies:

```bash
pip install -r requirements.txt
```

### Datasets

CIFAR-10 and CIFAR-100 are downloaded automatically at runtime if they are not already present. ImageNet is also downloaded automatically, provided you have a Hugging Face account set up.

CIFAR-10 DVS must be downloaded manually from https://drive.google.com/file/d/1s2csG5eagX3ZMfFpZCd5d7g8zqJxht4U/view and extracted to the `/data` folder.

## Reproducing Results

### Full Experimental Suite

To reproduce all results from the paper:

1. **Standard LNM Results**:

   ```bash
   sbatch sbatch_scripts/run_lnm.sh
   ```

2. **Ablation Studies**:

   ```bash
   sbatch sbatch_scripts/run_ablation.sh
   ```

### Individual Experiments

Each dataset and configuration has dedicated scripts in the `sbatch_scripts/` directory organized by method and dataset.

### Supported Models

- **LNMResNet19**: ResNet-19 architecture with LNM neurons
- **LNMVGG**: VGG architecture with LNM neurons

### Supported Datasets

- **CIFAR-10**: Standard 32×32 RGB image classification
- **CIFAR-100**: 100-class image classification
- **CIFAR-10-DVS**: Neuromorphic version of CIFAR-10
- **ImageNet**: Large-scale image classification

## Architecture

The LNM neuron model extends the traditional LIF neuron with learnable polynomial membrane dynamics:

```math
v(t+1) = v(t) + f_\theta(v(t)) + I(t)
```

Here `v(t)` is the membrane potential, `I(t)` is the input current, and `f_θ` is a learnable polynomial parameterized by θ, so the membrane potential update adapts to the task and the data.
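As a rough illustration of the update rule above, the following dependency-free sketch steps a single neuron through time. It is not the repository's implementation: the function names, the firing threshold `v_th`, the hard reset to `v_reset`, and the coefficient values are all assumptions made for this example, and in training the coefficients would be learned rather than fixed. It also shows that a degree-1 polynomial `f(v) = -v / tau` reduces the update to a leaky integrator.

```python
from typing import List, Sequence, Tuple


def poly(theta: Sequence[float], v: float) -> float:
    """Evaluate f_theta(v) = theta[0] + theta[1]*v + ... using Horner's method."""
    result = 0.0
    for coeff in reversed(theta):
        result = result * v + coeff
    return result


def lnm_step(v: float, current: float, theta: Sequence[float],
             v_th: float = 1.0, v_reset: float = 0.0) -> Tuple[float, int]:
    """Apply v(t+1) = v(t) + f_theta(v(t)) + I(t), then threshold and reset.

    The threshold and hard reset are assumptions for illustration only.
    """
    v_next = v + poly(theta, v) + current
    spike = 1 if v_next >= v_th else 0
    if spike:
        v_next = v_reset
    return v_next, spike


def simulate(currents: Sequence[float], theta: Sequence[float],
             v0: float = 0.0) -> List[int]:
    """Run the neuron over a sequence of input currents and return its spikes."""
    v = v0
    spikes = []
    for current in currents:
        v, spike = lnm_step(v, current, theta)
        spikes.append(spike)
    return spikes


# Degree-1 case: f(v) = -v / tau gives a leaky integrator (LIF-like dynamics).
tau = 2.0
lif_theta = [0.0, -1.0 / tau]

# Hypothetical degree-3 coefficients; made up for this example, not learned.
cubic_theta = [0.0, -0.5, 0.0, 0.4]

inputs = [0.5, 0.5, 0.5, 0.5]  # constant input over 4 time steps
print("LIF-like spikes:", simulate(inputs, lif_theta))
print("Cubic spikes:   ", simulate(inputs, cubic_theta))
```

With the same input, the two coefficient sets can produce different spike trains, which is the extra expressiveness that learning `f_θ` is meant to provide.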

### Key Components

- **`src/SNN/Layers/LNM/`**: Core LNM neuron implementation
- **`src/SNN/models/classification/`**: Model architectures (ResNet and VGG)
- **`src/SNN/LearnableMembrane/`**: Learnable membrane potential functions
- **`src/AbstractModels/`**: Training framework and utilities
- **`src/Graphs/Experiments/`**: Ablation study plots
- **`src/Graphs/LearnableNeuronUpdate/`**: Visualization of the neuron models learned by an SNN
- **`src/Graphs/SpikeRateEnergyConsumption/`**: Energy consumption calculation for models


## Results

Our method achieves state-of-the-art performance on multiple benchmarks:

| Dataset | Model | Accuracy | Time Steps |
|---------|-------|----------|------------|
| CIFAR-10 | LNMResNet19 | 97.01% | 4 |
| CIFAR-100 | LNMResNet19 | 80.70% | 4 |
| CIFAR-10 DVS | VGGSNN | 82.95% | 10 |
| CIFAR-10 DVS | LNMResNet19 | 81.39% | 10 |
| ImageNet | LNMResNet19 | 70.87% | 4 |

## Project Structure

```text
LNM-submission-ready/
├── src/
│   ├── main.py                    # Main training script
│   ├── AbstractModels/            # Training framework
│   ├── SNN/
│   │   ├── Layers/LNM/           # LNM neuron implementation
│   │   ├── models/               # Neural network architectures
│   │   ├── LearnableMembrane/    # Learnable membrane functions
│   │   └── SurrogateGradients/   # Gradient estimation
│   └── Datasets/                 # Dataset utilities
├── sbatch_scripts/               # Experiment scripts
│   ├── LNM/                     # Standard LNM experiments
│   └── AblationStudy/           # Ablation study scripts
└── requirements.txt             # Dependencies
```

## Command Line Interface

### Basic Training

To train an LNM ResNet-19 model with a degree-5 polynomial on CIFAR-10:

```bash
cd src
python3 main.py \
    --dataset cifar10 \
    --model lnmresnet19 \
    --loss crossentropy \
    --sg rectangle \
    --encoder copy \
    --decode mean \
    --epochs 100 \
    --optimizer sgd \
    --lr 0.1 \
    --l2 1e-4 \
    --momentum 0.9 \
    --seq_length 4 \
    --batch_size 128 \
    --scheduler cosine \
    --poly_degree 5
```
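For local runs without a SLURM scheduler, a small helper can generate variations of the command above, for example to compare polynomial degrees. The sketch below only builds and prints the commands. The flag names and default values are copied from the command above, while `build_command` and the degrees `(1, 3, 5)` are hypothetical choices for illustration, not the settings used in the paper's ablation study.

```python
import shlex
from typing import Dict, List, Optional

# Flags and values copied from the basic training command in this README.
BASE_ARGS: Dict[str, str] = {
    "dataset": "cifar10",
    "model": "lnmresnet19",
    "loss": "crossentropy",
    "sg": "rectangle",
    "encoder": "copy",
    "decode": "mean",
    "epochs": "100",
    "optimizer": "sgd",
    "lr": "0.1",
    "l2": "1e-4",
    "momentum": "0.9",
    "seq_length": "4",
    "batch_size": "128",
    "scheduler": "cosine",
}


def build_command(poly_degree: int,
                  overrides: Optional[Dict[str, str]] = None) -> List[str]:
    """Return the argv list for one training run (hypothetical helper)."""
    args = dict(BASE_ARGS)
    args["poly_degree"] = str(poly_degree)
    if overrides:
        args.update(overrides)
    command = ["python3", "main.py"]
    for name, value in args.items():
        command += [f"--{name}", value]
    return command


# Hypothetical sweep values; adjust to the degrees you want to compare.
for degree in (1, 3, 5):
    print(shlex.join(build_command(degree)))
```

Each printed line can be pasted into a shell from the `src` directory, or the list returned by `build_command` can be passed to `subprocess.run(..., cwd="src", check=True)` to launch the run directly.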

To view all available command-line arguments, run:

```bash
python3 main.py --help
```
