# NNiT: Neural Network Diffusion Transformer

This repository contains the implementation of NNiT, a diffusion model designed to generate neural network parameters and architectures.

## Environment Setup

Create a new Conda environment and install the required dependencies:

```bash
conda create -n nnit python=3.11
conda activate nnit

# Install PyTorch
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124

# Install ManiSkill and other dependencies
pip install --upgrade mani_skill
pip install tensorboard accelerate torchrl tensordict pandas timm wandb
```
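To confirm that the pinned versions actually resolved, you can query the installed package metadata from Python. The sketch below uses only the standard library. The `check_environment` helper is hypothetical and not part of the repository.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions pinned in the install commands above; None means "only check presence".
EXPECTED = {
    "torch": "2.5.1",
    "torchvision": "0.20.1",
    "torchaudio": "2.5.1",
    "mani_skill": None,
    "accelerate": None,
}

def check_environment(expected):
    """Return {package: problem} for every package that is missing or mismatched."""
    problems = {}
    for pkg, wanted in expected.items():
        try:
            installed = version(pkg)
        except PackageNotFoundError:
            problems[pkg] = "not installed"
            continue
        # CUDA wheels report versions like "2.5.1+cu124"; compare the public part only.
        public = installed.split("+")[0]
        if wanted is not None and public != wanted:
            problems[pkg] = f"found {installed}, expected {wanted}"
    return problems

if __name__ == "__main__":
    issues = check_environment(EXPECTED)
    for pkg, msg in issues.items():
        print(f"{pkg}: {msg}")
    print("OK" if not issues else f"{len(issues)} problem(s)")
```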

---

## Configuration (`configs/`)

The `configs` directory allows you to control experiments without modifying the core code. It is structured by environment/experiment (e.g., `pickcube_config`, `pushcube_config`).

Each config package typically contains:

*   **`config/`**: Directory containing JSON files defining architecture constraints (e.g., `train_arch_4layer.json`, `test_arch_4layer.json`).
*   **`train.yaml`**: Main configuration for the training loop (hyperparameters, dataset paths, logging).
*   **`sample_joint.yaml`**: Configuration for sampling architectures and weights jointly.
*   **`sample_a2w_*.yaml`**: Configuration for conditional generation (Architecture-to-Weight).

### Key Config Parameters
*   `dataset_path`: Path to the HDF5 dataset containing weights and architectures.
*   `model`: Parameters for the NNiT model (hidden size, depth, patch size).
*   `diffusion`: Diffusion process parameters (beta schedule, timesteps).
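The beta schedule and number of timesteps together determine how much noise is added at each diffusion step. The following is a minimal sketch assuming a DDPM-style linear schedule from `1e-4` to `0.02`; the schedule types and default values actually supported by the `diffusion` config section may differ.

```python
def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced betas, as in the original DDPM setup (assumed defaults)."""
    if timesteps < 2:
        raise ValueError("timesteps must be >= 2")
    step = (beta_end - beta_start) / (timesteps - 1)
    return [beta_start + i * step for i in range(timesteps)]

def alphas_cumprod(betas):
    """alpha_bar_t = prod_{s<=t} (1 - beta_s): the fraction of signal kept at step t."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out

betas = linear_beta_schedule(1000)
alpha_bar = alphas_cumprod(betas)
print(f"beta_0={betas[0]:.4g}, beta_T={betas[-1]:.4g}, alpha_bar_T={alpha_bar[-1]:.3g}")
```

A small `alpha_bar_T` means the final step is close to pure Gaussian noise, which is what lets sampling start from random noise.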

---

## Source Code (`src/`)

The core logic resides in the `src/` directory.

### `src/dataset.py`
Handles data loading and processing.
*   **`MultiHDF5ArchitectureWeightDataset`**: The primary dataset class.
    *   Loads neural network weights and architectures from HDF5 files.
    *   **Architecture Encoding**: Converts architecture definitions into token sequences (Input -> Hidden Layers -> Output).
    *   **Weight Processing**: Extracts weights/biases, handles quantization/normalization, and prepares them as patches for the transformer.
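To make "architecture tokens" and "weight patches" concrete for a small MLP, here is a plain-Python sketch. The `(role, width)` token scheme, the role ids, and the zero-padding of the last patch are hypothetical illustrations; the actual encoding in `MultiHDF5ArchitectureWeightDataset` (including quantization/normalization) may differ.

```python
# Hypothetical token ids for layer roles; the real vocabulary may differ.
INPUT, HIDDEN, OUTPUT = 0, 1, 2

def encode_architecture(input_dim, hidden_widths, output_dim):
    """Encode an MLP as (role, width) tokens ordered Input -> Hidden Layers -> Output."""
    tokens = [(INPUT, input_dim)]
    tokens += [(HIDDEN, w) for w in hidden_widths]
    tokens.append((OUTPUT, output_dim))
    return tokens

def layer_shapes(tokens):
    """Weight matrix shapes (out, in) implied by consecutive widths."""
    widths = [w for _, w in tokens]
    return [(widths[i + 1], widths[i]) for i in range(len(widths) - 1)]

def to_patches(flat_params, patch_size, pad_value=0.0):
    """Split a flat parameter list into fixed-size patches, padding the last one."""
    patches = []
    for start in range(0, len(flat_params), patch_size):
        patch = list(flat_params[start:start + patch_size])
        patch += [pad_value] * (patch_size - len(patch))
        patches.append(patch)
    return patches

tokens = encode_architecture(input_dim=3, hidden_widths=[4], output_dim=2)
shapes = layer_shapes(tokens)  # [(4, 3), (2, 4)]
# Weights plus one bias per output unit, for each layer: 16 + 10 = 26.
n_params = sum(o * i + o for o, i in shapes)
patches = to_patches([0.1] * n_params, patch_size=4)
print(len(patches), patches[-1])
```

Fixed-size patches let networks of different widths share one transformer input format; padding marks where a layer's parameters run out.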

### `src/model.py`
Defines the neural network architecture.
*   **`NNiT`**: The main class, a **Diffusion Transformer (DiT)** adapted for neural networks.
    *   **Input**: Takes both architecture tokens and weight patches.
    *   **Positional Embeddings**: Uses specialized 3D positional embeddings to represent the structural location of weights within the generated network (Layer, Block, Component).
    *   **DiTBlock**: Standard Diffusion Transformer blocks with adaptive layer normalization (adaLN) for conditioning on timesteps.
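Two of the ideas above can be sketched in plain Python: a 3D positional embedding formed by concatenating sinusoidal embeddings of the (layer, block, component) indices, and the adaLN modulation `x * (1 + scale) + shift` used in DiT blocks. The concatenation scheme is an assumption for illustration; `NNiT` may build its positional embeddings differently.

```python
import math

def sincos_1d(pos, dim):
    """Standard sinusoidal embedding of an integer position (dim must be even)."""
    half = dim // 2
    freqs = [1.0 / (10000 ** (i / half)) for i in range(half)]
    return [math.sin(pos * f) for f in freqs] + [math.cos(pos * f) for f in freqs]

def pos_embed_3d(layer, block, component, dim_per_axis):
    """Assumed scheme: one sinusoidal embedding per axis, concatenated."""
    return (sincos_1d(layer, dim_per_axis)
            + sincos_1d(block, dim_per_axis)
            + sincos_1d(component, dim_per_axis))

def layer_norm(x, eps=1e-6):
    """LayerNorm without learned affine parameters, as used before adaLN in DiT."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def modulate(x, shift, scale):
    """adaLN modulation from DiT: x * (1 + scale) + shift, elementwise."""
    return [v * (1.0 + s) + b for v, s, b in zip(x, scale, shift)]

emb = pos_embed_3d(layer=1, block=0, component=2, dim_per_axis=4)
h = layer_norm([1.0, 2.0, 3.0, 4.0])
# In a DiT block, shift and scale are regressed from the timestep embedding;
# all-zero values make the modulation an identity.
out = modulate(h, shift=[0.0] * 4, scale=[0.0] * 4)
print(len(emb), [round(v, 3) for v in out])
```

Because the modulation is an identity when shift and scale are zero, DiT initializes the regressing layer to zero so each block starts close to identity.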

### Other Modules
*   **`src/diffusion/`**: Contains the diffusion noise schedulers and sampling logic.
*   **`src/dit_modules.py`**: Helper modules for the transformer (Embedders, Attention blocks).
*   **`src/envs/`**: Environment interfaces (likely for RL evaluation of generated networks).
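The scheduler and sampling logic typically revolve around two operations: the forward process that noises clean parameters, and a reverse step that denoises using the model's prediction. This is a plain-Python sketch assuming DDPM with epsilon-prediction and `sigma_t^2 = beta_t`; `src/diffusion/` may use a different parameterization or sampler.

```python
import math
import random

def make_schedule(timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Linear betas and their cumulative products alpha_bar (assumed defaults)."""
    betas = [beta_start + (beta_end - beta_start) * i / (timesteps - 1) for i in range(timesteps)]
    alpha_bar, running = [], 1.0
    for b in betas:
        running *= 1.0 - b
        alpha_bar.append(running)
    return betas, alpha_bar

def q_sample(x0, t, alpha_bar, noise):
    """Forward process: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps."""
    a = alpha_bar[t]
    return [math.sqrt(a) * x + math.sqrt(1.0 - a) * e for x, e in zip(x0, noise)]

def p_sample_step(xt, t, eps_pred, betas, alpha_bar, rng):
    """One DDPM reverse step given the model's noise prediction."""
    beta, a_bar = betas[t], alpha_bar[t]
    coef = beta / math.sqrt(1.0 - a_bar)
    mean = [(x - coef * e) / math.sqrt(1.0 - beta) for x, e in zip(xt, eps_pred)]
    if t == 0:
        return mean  # no noise is added at the final step
    return [m + math.sqrt(beta) * rng.gauss(0.0, 1.0) for m in mean]

rng = random.Random(0)
betas, alpha_bar = make_schedule()
x0 = [0.5, -0.25, 1.0]  # e.g. a flattened weight patch
eps = [rng.gauss(0.0, 1.0) for _ in x0]
x_mid = q_sample(x0, t=500, alpha_bar=alpha_bar, noise=eps)
# Index 0 is the least-noised step; with a perfect noise prediction,
# the final reverse step recovers x0 up to floating-point error.
x_first = q_sample(x0, t=0, alpha_bar=alpha_bar, noise=eps)
x0_hat = p_sample_step(x_first, 0, eps, betas, alpha_bar, rng)
print([round(v, 6) for v in x0_hat])
```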

---

## Usage

### 1. Training

To train the model using **Multiple GPUs** (recommended):

```bash
accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 train.py --config configs/<config_folder>/train.yaml
```

To train with a **Single GPU**:

```bash
accelerate launch --num_processes 1 --mixed_precision fp16 train.py --config configs/<config_folder>/train.yaml
```
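When sweeping over several config folders, it can be convenient to build these commands programmatically. The `build_train_command` helper below is hypothetical and only reproduces the flags shown above; `pickcube_config` is one of the example folders from `configs/`.

```python
import shlex

def build_train_command(config_folder, num_gpus=1, mixed_precision="fp16"):
    """Return the accelerate launch command for configs/<config_folder>/train.yaml."""
    cmd = ["accelerate", "launch"]
    if num_gpus > 1:
        cmd.append("--multi_gpu")  # only needed for multi-GPU data parallelism
    cmd += [
        "--num_processes", str(num_gpus),
        "--mixed_precision", mixed_precision,
        "train.py", "--config", f"configs/{config_folder}/train.yaml",
    ]
    return cmd

for gpus in (1, 4):
    # Pass the list to subprocess.run(cmd, check=True) to actually launch training.
    print(shlex.join(build_train_command("pickcube_config", num_gpus=gpus)))
```

Building an argument list (rather than a shell string) avoids quoting issues when config paths contain spaces.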

### 2. Sampling (Generation)

To generate new networks (Architecture + Weights), configure the appropriate YAML file in your target config folder and run:

```bash
python sample.py --config <config_name>
```

**Output:**
*   Generated samples are saved in `samples/policy_N/`.
*   `architecture.json`: Describes the network structure (Input + MLP layers + Output).
*   Weights file (inside each policy folder): Contains the actual generated weights.
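To inspect a batch of generated samples, you can collect every `architecture.json` under `samples/policy_N/`. Since the JSON schema is not documented here, this sketch returns the parsed content as-is. The `load_generated_architectures` helper is hypothetical.

```python
import json
import re
from pathlib import Path

POLICY_DIR = re.compile(r"policy_(\d+)")

def load_generated_architectures(samples_dir="samples"):
    """Return {folder_name: parsed architecture.json} for samples/policy_N folders."""
    root = Path(samples_dir)
    if not root.is_dir():
        return {}
    found = []
    for path in root.iterdir():
        match = POLICY_DIR.fullmatch(path.name)
        if match and path.is_dir():
            found.append((int(match.group(1)), path))
    results = {}
    # Numeric sort so policy_10 comes after policy_9 rather than after policy_1.
    for _, folder in sorted(found):
        arch_file = folder / "architecture.json"
        if arch_file.exists():
            results[folder.name] = json.loads(arch_file.read_text())
    return results

if __name__ == "__main__":
    archs = load_generated_architectures()
    print(f"loaded {len(archs)} architectures:", list(archs)[:5])
```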

### 3. Processing Samples

The script `src/sample_util.py` helps convert the raw generated output back into executable PyTorch models. These can then be evaluated using `HyperPPO/sf_examples/brax/enjoy_difnn.ipynb` to visualize performance (video/reward).
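Conceptually, turning a sample into an executable network means slicing the flat generated parameters into per-layer weights and biases according to the architecture, then running a forward pass. The plain-Python sketch below assumes a row-major "weights then bias" layout per layer and `tanh` hidden activations; both are assumptions, and `src/sample_util.py` builds PyTorch models rather than these lists.

```python
import math

def unflatten_mlp(flat, widths):
    """Slice a flat parameter list into (W, b) per layer.

    Assumed layout: for each layer, the (out x in) weight matrix in row-major
    order, followed by its bias vector.
    """
    layers, pos = [], 0
    for n_in, n_out in zip(widths[:-1], widths[1:]):
        W = [flat[pos + r * n_in: pos + (r + 1) * n_in] for r in range(n_out)]
        pos += n_in * n_out
        b = flat[pos: pos + n_out]
        pos += n_out
        layers.append((W, b))
    if pos != len(flat):
        raise ValueError(f"expected {pos} parameters, got {len(flat)}")
    return layers

def mlp_forward(layers, x, activation=math.tanh):
    """Apply each layer; the activation is skipped on the output layer."""
    for i, (W, b) in enumerate(layers):
        x = [sum(w * v for w, v in zip(row, x)) + bias for row, bias in zip(W, b)]
        if i < len(layers) - 1:
            x = [activation(v) for v in x]
    return x

widths = [2, 3, 1]  # input, one hidden layer, output
n_params = sum(i * o + o for i, o in zip(widths[:-1], widths[1:]))  # 9 + 4 = 13
layers = unflatten_mlp([0.0] * n_params, widths)
print(mlp_forward(layers, [1.0, -1.0]))  # all-zero parameters give [0.0]
```

Checking that the parameter count matches the architecture exactly is a cheap way to catch samples whose generated weights and structure disagree.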

### 4. Testing

To run a full suite of sampling tests:

```bash
python sample.py --config sample_joint && \
python sample_a2w.py --config sample_train && \
python sample_a2w.py --config sample_test
```
