# SNN Fault-Tolerance and Hardware Co-Design Framework

This `README.md` explains how to run the files in our supplementary material.

This project is a PyTorch-based research framework for Spiking Neural Networks (SNNs). It focuses on simulating faults in hardware-implemented SNNs, benchmarking fault-tolerance techniques, and providing pathways for hardware (VHDL) export.

The framework is built using the `spikingjelly` library for SNN components.

## 🚀 Core Features

  * **SNN Models:** Includes SNN implementations for a simple MLP (`simple_snn.py`), VGG (`vgg_snn.py`), and ResNet (`resnet_snn.py`).
  * **Advanced Fault Injection:** `fault_injection.py` provides a `FaultManager` to simulate various hardware faults:
      * **Types:** `stuck` (stuck-at-0/1), `random` (additive noise), `connectivity` (random weights).
      * **Distributions:** `sporadic` (randomly distributed) and `clustered` (affecting entire neurons/rows).
  * **Input Fragmentation:** A novel input processing technique (`fragmentation.py`) that dynamically or manually splits an input image into a sequence of temporal "fragments" for the SNN.
  * **Fault-Tolerance Benchmarks:** The training scripts include optional hooks (implemented in `benchmarks.py`, which is not included in this material) for several tolerance techniques, such as:
      * `ECOC` (Error-Correcting Output Codes)
      * `Routing` (dynamic routing around faulty units)
      * `LIFA`, `Astrocyte`, `Falvolt` (other mitigation methods)
  * **VHDL Export:** `vhdl_simple_snn.py` demonstrates training a simple SNN and exporting its architecture and weights to VHDL for hardware synthesis using the `spikerplus` library.
  * **In-Depth Metrics Wrapper:** `addon_metrics_wrapper.py` wraps a training script, logs detailed per-iteration statistics (e.g., gradient norms, weight saturation, neuron activation ('z') values) to a CSV file, and can optionally generate plots.
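To make the fault model above concrete, here is a small NumPy sketch of the three fault types combined with the two distributions. It is not the `FaultManager` API: the function name `inject_faults`, its arguments, the Gaussian noise scale, and the choice to model stuck-at-0/1 as the weight values 0.0/1.0 are all assumptions based only on the descriptions in this list.

```python
import numpy as np


def inject_faults(weights, fault_type="stuck", fault_dist="sporadic",
                  fault_ratio=0.2, noise_std=0.1, seed=0):
    """Hypothetical sketch of the fault types/distributions listed above.

    weights: 2-D array (out_features x in_features), e.g. one linear layer.
    Returns (faulty_copy, mask) where mask marks the affected entries.
    """
    rng = np.random.default_rng(seed)
    w = np.array(weights, dtype=float)  # np.array copies, original is untouched
    n_rows = w.shape[0]

    # 1) Decide WHERE faults occur.
    if fault_dist == "sporadic":
        # Individual weights, scattered independently at random.
        mask = rng.random(w.shape) < fault_ratio
    elif fault_dist == "clustered":
        # Whole rows fail together, i.e. every input weight of one neuron.
        n_bad = int(round(fault_ratio * n_rows))
        bad_rows = rng.choice(n_rows, size=n_bad, replace=False)
        mask = np.zeros(w.shape, dtype=bool)
        mask[bad_rows, :] = True
    else:
        raise ValueError(f"unknown fault_dist: {fault_dist!r}")

    # 2) Decide WHAT the fault does.
    n = int(mask.sum())
    if fault_type == "stuck":
        # Simplification (assumption): stuck-at-0/1 as the value 0.0 or 1.0.
        w[mask] = rng.integers(0, 2, size=n).astype(float)
    elif fault_type == "random":
        # Additive Gaussian noise on top of the trained value.
        w[mask] += rng.normal(0.0, noise_std, size=n)
    elif fault_type == "connectivity":
        # Replace the connection with a random weight from the layer's range.
        w[mask] = rng.uniform(w.min(), w.max(), size=n)
    else:
        raise ValueError(f"unknown fault_type: {fault_type!r}")
    return w, mask
```

For example, `inject_faults(w, "stuck", "clustered", 0.3)` on a 10-neuron layer faults every input weight of 3 neurons, whereas `"sporadic"` with the same ratio scatters roughly 30% of individual weights across the whole matrix.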

## 🛠️ Setup

This project requires Python and several packages.

**Core Dependencies:**

  * `torch` (PyTorch)
  * `spikingjelly`
  * `numpy`
  * `scikit-learn` (for metrics)
  * `matplotlib`, `seaborn` (for plotting)

**For VHDL Export:**

  * `spikerplus` (This library must be provided separately, see `vhdl_simple_snn.py`)

You can install most dependencies via pip:

```bash
pip install torch torchvision spikingjelly numpy scikit-learn matplotlib seaborn
```

For details on the VHDL generation flow we used, refer to the paper *Spiker+: a framework for the generation of efficient Spiking Neural Networks FPGA accelerators for inference at the edge*.

## 🏃 How to Run

All main scripts (`simple_snn.py`, `vgg_snn.py`, `resnet_snn.py`) are configured using command-line arguments.
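As a rough sketch of how such a command-line interface can be declared with `argparse`: the flag names below are the ones used in this README's examples, but the types, defaults, and the `str2bool` helper are assumptions, not the scripts' actual definitions. A helper like `str2bool` matters for flags such as `--Fault True`, because `argparse` with `type=bool` would treat any non-empty string, including `"False"`, as `True`.

```python
import argparse


def str2bool(value):
    """Hypothetical helper: parse 'True'/'False'-style flag values."""
    if isinstance(value, bool):
        return value
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser():
    # Flag names follow this README's examples; types/defaults are assumptions.
    p = argparse.ArgumentParser(description="SNN training (illustrative)")
    p.add_argument("--data_path", type=str, default="./propdata/MNIST")
    p.add_argument("--num_epochs", type=int, default=10)
    p.add_argument("--num_steps", type=int, default=4)
    p.add_argument("--Fault", type=str2bool, default=False)
    p.add_argument("--fault_type", default="stuck",
                   choices=["stuck", "random", "connectivity"])
    p.add_argument("--fault_dist", default="sporadic",
                   choices=["sporadic", "clustered"])
    p.add_argument("--fault_ratio", type=float, default=0.0)
    p.add_argument("--fault_start_epoch", type=int, default=0)
    p.add_argument("--Frag", type=str2bool, default=False)
    return p


# Example: parse an explicit argument list instead of sys.argv.
args = build_parser().parse_args(
    ["--Fault", "True", "--fault_ratio", "0.2", "--Frag", "False"])
```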

### 1. Standard Training

To run a basic training of the simple MLP on MNIST:

```bash
python simple_snn.py --data_path ./propdata/MNIST --num_epochs 10
```

To run a ResNet-34 on CIFAR-100:

```bash
python resnet_snn.py --data_path ./propdata/CIFAR100 --resnet_depth 34 --num_epochs 50
```

### 2. Training with Faults and Fragmentation

These arguments can be combined. For example, the following command trains a VGG-SNN on CIFAR-10 for 50 epochs with 4 time steps, injects sporadic stuck-at faults with a fault ratio of 0.2 (20%) after epoch 5, and feeds the network fragmented inputs (`--Frag True`).

```bash
python vgg_snn.py \
    --data_path ./propdata/CIFAR10 \
    --num_epochs 50 \
    --num_steps 4 \
    --Fault True \
    --fault_type stuck \
    --fault_dist sporadic \
    --fault_ratio 0.2 \
    --fault_start_epoch 5 \
    --Frag True
```
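The fragmentation idea can be illustrated with a minimal NumPy sketch of one plausible *manual* scheme: the image is cut into horizontal bands, and each time step sees only one band. This is an assumption made for illustration; the actual logic in `fragmentation.py` (including its dynamic mode, which is not shown) may split the input differently. The function name `manual_fragments` is hypothetical, and tying the number of fragments to `--num_steps` is also an assumption.

```python
import numpy as np


def manual_fragments(image, num_steps):
    """Hypothetical manual fragmentation.

    Splits an image of shape (C, H, W) into `num_steps` horizontal bands and
    returns a (T, C, H, W) sequence in which time step t contains only band t;
    every other pixel of that frame is zero.
    """
    c, h, w = image.shape
    frames = np.zeros((num_steps, c, h, w), dtype=image.dtype)
    # Band boundaries; linspace also handles H not divisible by num_steps.
    edges = np.linspace(0, h, num_steps + 1).astype(int)
    for t in range(num_steps):
        lo, hi = edges[t], edges[t + 1]
        frames[t, :, lo:hi, :] = image[:, lo:hi, :]
    return frames
```

Under this reading, `--num_steps 4` on a 28x28 image yields four 7-row fragments fed to the SNN one per time step, and summing all frames recovers the original image.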

### 3. Using the Metrics Wrapper

To get detailed logs on internal network state (like weight saturation and gradients), use the `addon_metrics_wrapper.py`. It runs the target script for you.
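Two of the logged quantities can be illustrated with short NumPy helpers. The definitions below are common choices but are assumptions: the wrapper's exact formulas are not documented here, and `weight_saturation` in particular assumes weights are clipped to a bound `w_max`.

```python
import numpy as np


def grad_l2_norm(grads):
    """Global L2 norm over a list of gradient arrays (one per layer)."""
    return float(np.sqrt(sum(float(np.sum(np.square(g))) for g in grads)))


def weight_saturation(weights, w_max, tol=1e-3):
    """Hypothetical metric: fraction of weights whose magnitude lies within
    `tol` of the clipping bound `w_max`."""
    w = np.abs(np.asarray(weights, dtype=float))
    return float(np.mean(w >= w_max - tol))
```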

**Note the standalone `--`**, which separates the wrapper's arguments from the target script and its arguments.

```bash
python addon_metrics_wrapper.py \
    --metrics_csv "mnist_fault_run.csv" \
    --plot True \
    --fig_dir "run_figures" \
    -- \
    simple_snn.py \
        --data_path ./propdata/MNIST \
        --num_epochs 15 \
        --Fault True \
        --fault_ratio 0.3
```

This will run `simple_snn.py` and produce `mnist_fault_run.csv` and a folder `run_figures` with plots.
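The `--` convention can be sketched as follows: everything before the first standalone `--` is parsed by the wrapper, and everything after it is forwarded unchanged as the target command. This illustrates the convention only; `split_wrapper_args` and `parse_wrapper_cli` are hypothetical names, and `addon_metrics_wrapper.py` may hand off to the target script differently.

```python
import argparse
import sys


def split_wrapper_args(argv):
    """Split argv at the first standalone '--' (hypothetical helper)."""
    if "--" in argv:
        i = argv.index("--")
        return argv[:i], argv[i + 1:]
    return list(argv), []


def parse_wrapper_cli(argv):
    """Parse wrapper flags and build the command that runs the target script."""
    wrapper_argv, target_argv = split_wrapper_args(argv)
    p = argparse.ArgumentParser(description="metrics wrapper (illustrative)")
    p.add_argument("--metrics_csv", default="metrics.csv")
    p.add_argument("--fig_dir", default="figures")
    p.add_argument("--plot", default="False")
    wrapper_args = p.parse_args(wrapper_argv)
    # target_argv[0] is the script (e.g. simple_snn.py); the rest are its own
    # flags, forwarded untouched so the wrapper never needs to know them.
    command = [sys.executable] + target_argv
    return wrapper_args, command
```

Launching the target could then be done with `subprocess.run(command, check=True)`; how the real wrapper collects per-iteration statistics from inside the training loop is not shown.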

### 4. Exporting to VHDL

The `vhdl_simple_snn.py` script trains a model and then exports it to VHDL. Pass the path to your local `spiker_public` repository with `--spiker_root`.

```bash
python vhdl_simple_snn.py \
    --data_path ./propdata/MNIST \
    --num_epochs 5 \
    --vhdl_out ./my_snn_vhdl \
    --spiker_root ../path/to/spiker_public
```

Before exporting, change the `vhdl_gen.generate()` call to `vhdl_gen.generate(interface=True, functional=False)`. We also recommend using Xilinx Vivado to synthesize and simulate the generated VHDL code.

-----

## 📂 File Overview

  * **Main Training Scripts:**
      * `simple_snn.py`: A 4-layer MLP SNN for MNIST/FMNIST/UCI-HAR.
      * `vgg_snn.py`: A VGG-style SNN for CIFAR.
      * `resnet_snn.py`: A ResNet-style SNN for CIFAR/ImageNet.
  * **Core Modules:**
      * `fault_injection.py`: The `FaultManager` class for simulating all hardware faults.
      * `fragmentation.py`: Implements the dynamic and manual input fragmentation logic.
      * `utils.py`: Contains helper classes like `TDBatchNorm` (Time-Domain Batch Norm) and `ZBiasAdder`.
  * **Tools & Export:**
      * `vhdl_simple_snn.py`: Trains a simple SNN and exports it to VHDL.
      * `addon_metrics_wrapper.py`: A powerful wrapper for logging and plotting internal model statistics.
  * **Supporting Files (Imported by others):**
      * `benchmarks.py`: Not included in this material, but imported by the training scripts. Contains hooks for fault-tolerance methods (ECOC, LIFA, etc.).
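For readers unfamiliar with ECOC, its core decoding step is sketched below: each class is assigned a binary codeword, the network produces one bit per output unit, and the class whose codeword is nearest in Hamming distance wins, so a few flipped outputs (e.g., from faulty output neurons) can still decode correctly. The 4-class, 7-bit codebook is a toy example; since `benchmarks.py` is not provided, its actual codebook and decoding rule are unknown.

```python
import numpy as np

# Toy codebook (assumption): 4 classes x 7-bit codewords. The minimum pairwise
# Hamming distance is 4, so any single flipped bit still decodes correctly.
CODEBOOK = np.array([
    [0, 0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 1],
    [1, 1, 0, 0, 1, 1, 1],
    [0, 0, 1, 1, 1, 1, 0],
])


def ecoc_decode(outputs, codebook=CODEBOOK, threshold=0.5):
    """Map per-bit outputs (N x L, values in [0, 1]) to class indices by
    choosing the codeword with the smallest Hamming distance."""
    bits = (np.asarray(outputs) >= threshold).astype(int)          # N x L
    # Hamming distance of every sample to every codeword -> N x K
    dist = (bits[:, None, :] != codebook[None, :, :]).sum(axis=2)
    return dist.argmin(axis=1)
```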

