# Dynamic-Neural-Graph

This folder contains experiments from “Dynamic Neural Graph: Facilitating Temporal Dynamics Learning in Deep Weight Space”.

## Setup

Before running the code, ensure that your current working directory is the root directory of the repository. Use the following commands to set up the virtual environment:

```bash
conda env create -f environment.yml
conda activate dng
```
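
As an optional sanity check, the sketch below tests whether the current directory looks like the repository root. The helper name `in_repo_root` is hypothetical and not part of this repository. It only assumes that `environment.yml` and `generate_dng_data.py` sit at the root, as the commands in this README imply.

```bash
# Hypothetical helper (not part of this repository).
# Succeeds only if both files referenced from the repo root are present.
in_repo_root() {
  [ -f environment.yml ] && [ -f generate_dng_data.py ]
}

# Warn instead of aborting, so the check is safe to paste into any shell.
in_repo_root || echo "Warning: this does not look like the repository root." >&2
```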

## Generate Dynamic Neural Graph Data

First, generate the Dynamic Neural Graph (DNG) data for the implicit neural representations (INRs):

1. **Download Pre-trained SIREN Weights:**

   - Download SIREN weights trained on the MNIST, FashionMNIST, and CIFAR-10 datasets using [this link](https://drive.google.com/drive/folders/15CdOTPWHqDcS4SwbIdm100rXkTYZPcC5?usp=sharing).
   - Download SIREN weights trained on the CIFAR-100 dataset using [this link](https://drive.google.com/drive/folders/1TwUZmcE2XrGQXCPhGAIa_sXCd5kX8OIA?usp=sharing).

2. Unzip the downloaded weights into the `./data` directory.

3. Run the following command to generate the Dynamic Neural Graph data:

   ```bash
   python generate_dng_data.py --ds [dataset]
   ```

   Where `dataset` can be one of the following options:
   - `mnist` (for MNIST INR dataset)
   - `fashion` (for FashionMNIST INR dataset)
   - `cifar` (for CIFAR-10 INR dataset)
   - `cifar100` (for CIFAR-100 INR dataset)
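
To generate data for all four datasets in one pass, a loop over the options above is one possibility. The sketch below only prints the commands so they can be reviewed first. The function name `dng_generate_cmds` is hypothetical.

```bash
# Hypothetical helper: print one generation command per dataset option listed above.
dng_generate_cmds() {
  for ds in mnist fashion cifar cifar100; do
    echo "python generate_dng_data.py --ds $ds"
  done
}

# Review the commands; to execute them, pipe the output to a shell:
#   dng_generate_cmds | sh
dng_generate_cmds
```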

## INR2JLS and INR Classification

Classifying INRs with the INR2JLS framework takes two steps: first convert the INRs into latent feature maps, then classify the INRs based on those latent feature maps.

### Train the INR2JLS Framework

Run the following command to train the INR2JLS framework and save the DNG-Encoders used to generate latent feature maps:

```bash
python dng_inr2jls.py --aug --ds [dataset] --l-size [latent size]
```

Where:
- `dataset` can be one of `mnist`, `fashion`, `cifar`, or `cifar100`.
- `latent size` represents the size of the latent feature maps (H * W). In our experiments:
  - For MNIST INR dataset and FashionMNIST INR dataset, set it to `49`.
  - For CIFAR-10 INR dataset and CIFAR-100 INR dataset, set it to `64`.
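
The dataset-to-latent-size pairing above can be captured in a small shell helper so the two flags stay consistent. This is a minimal sketch: `latent_size_for` is a hypothetical name, and the values are only the ones used in the experiments listed above.

```bash
# Hypothetical helper: map a dataset name to the latent size listed above.
latent_size_for() {
  case "$1" in
    mnist|fashion)  echo 49 ;;
    cifar|cifar100) echo 64 ;;
    *) echo "unknown dataset: $1" >&2; return 1 ;;
  esac
}

# Print the training command for one dataset (remove `echo` to run it).
ds=mnist
echo python dng_inr2jls.py --aug --ds "$ds" --l-size "$(latent_size_for "$ds")"
```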

### Generate Latent Feature Maps and Classify INRs

Once you have the pre-trained DNG-Encoder, use it to generate latent feature maps and classify the INRs:

```bash
python dng_latent_classify.py --aug --ds [dataset] --l-size [latent size] --enc-dir [encoder directory]
```

Where:
- `dataset` and `latent size` should match the values used when training the INR2JLS framework.
- `encoder directory` is the folder containing the pre-trained DNG-Encoder, e.g., `dng_encoder_models_jls/mnist/2024_01_01_12_00_00`.
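
If several training runs exist, the newest encoder directory can be picked automatically. The sketch below assumes the layout suggested by the example path, `dng_encoder_models_jls/<dataset>/<timestamp>`, with timestamps that sort chronologically as text. The layout is inferred from that single example, and `latest_encoder_dir` is a hypothetical helper.

```bash
# Hypothetical helper: print the last (newest) run folder for a dataset.
# $1: dataset name; $2: optional root folder (defaults to the one in the example above).
latest_encoder_dir() {
  local root="${2:-dng_encoder_models_jls}"
  local latest=""
  local d
  for d in "$root/$1"/*/; do
    [ -d "$d" ] || continue   # skip the unexpanded pattern when nothing matches
    latest="${d%/}"           # glob results are sorted, so the last match wins
  done
  [ -n "$latest" ] && echo "$latest"
}

# Print the classification command for MNIST (remove `echo` to run it).
if enc_dir="$(latest_encoder_dir mnist)"; then
  echo python dng_latent_classify.py --aug --ds mnist --l-size 49 --enc-dir "$enc_dir"
else
  echo "No encoder directory found for mnist." >&2
fi
```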

### INR Classification using DNG-Encoder Only

Alternatively, you can directly classify the INRs using the DNG-Encoder without generating latent feature maps:

```bash
python dng_enc_classify.py --aug --ds [dataset]
```

## Editing INRs

You can stylize the INRs using the following command:

```bash
python dng_stylize_siren.py --aug --ds [dataset] --style [style]
```

Where:
- `dataset` can be `mnist` or `fashion`.
- `style` can be one of the following options: `dilate`, `erode`, `gradient`.
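
Because only some dataset and style values are supported, it can help to validate the arguments before launching a run. The following sketch prints the command for a valid combination and fails otherwise. `stylize_cmd` is a hypothetical helper, not part of this repository.

```bash
# Hypothetical helper: validate the arguments against the options listed above,
# then print the matching stylization command.
stylize_cmd() {
  case "$1" in
    mnist|fashion) ;;
    *) echo "unsupported dataset: $1" >&2; return 1 ;;
  esac
  case "$2" in
    dilate|erode|gradient) ;;
    *) echo "unsupported style: $2" >&2; return 1 ;;
  esac
  echo "python dng_stylize_siren.py --aug --ds $1 --style $2"
}

# Print the command for one valid combination; pipe to `sh` to execute it.
stylize_cmd mnist dilate
```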

## Predicting Generalization

### For CNN Models

To predict the generalization of CNN classifiers, first place the `cifar10` or `svhn` dataset under `./data/predict_gen`, then run:

```bash
python dng_predict_gen.py --ds [dataset] --sigmoid
```

Where:
- `dataset` can be `cifar10` or `svhn`.

### For ViT Models

We also provide an experiment for predicting the generalization of Vision Transformer (ViT) models. 

1. **Download the Dataset**  
   The dataset will be released once the paper is accepted.
   Once it is available, place it in the `./data` directory.

2. **Generate ViT Graph Data**  
   After placing the dataset in the correct directory, run the following command to generate the ViT graph data:  

   ```bash
   python generate_dng_data.py --type vit
   ```

3. **Predict the Generalization**  
   Finally, to predict the generalization of the ViT models, use the following command:

   ```bash
   python dng_predict_gen_vit.py --sigmoid
   ```