# DIGL: Distribution Invariant Graph Learning

This repository contains the PyTorch implementation of "DIGL: Distribution Invariant Graph Learning", a framework for robust graph classification that generalizes to out-of-distribution (OOD) data.

## Environment Requirements

The code has been tested with **Python 3.10**. Key requirements:

- python>=3.10
- torch>=2.0.0
- torch_geometric>=2.2.0
- numpy>=1.21.5
- scikit-learn>=1.0.2
- networkx>=2.6.3
- matplotlib>=3.5.0
- tqdm>=4.64.0
- scipy>=1.7.3
- ogb>=1.3.6

Install dependencies with:
```bash
pip install "torch>=2.0.0" "torch_geometric>=2.2.0" "numpy>=1.21.5" "scikit-learn>=1.0.2" \
    "networkx>=2.6.3" "matplotlib>=3.5.0" "tqdm>=4.64.0" "scipy>=1.7.3" "ogb>=1.3.6"
```
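
Before a long run, it can help to confirm that the installed versions meet the minimums above. The script below is a minimal sketch that is not part of the repository; `check_environment` and `version_tuple` are hypothetical helper names, and it uses only the standard library.

```python
"""Check installed packages against the minimum versions listed in this README."""
import importlib
import importlib.util
import re
import sys

# Minimum versions from the requirements list. Keys are import names, which differ
# from PyPI names for scikit-learn (imported as "sklearn").
MIN_VERSIONS = {
    "torch": "2.0.0",
    "torch_geometric": "2.2.0",
    "numpy": "1.21.5",
    "sklearn": "1.0.2",
    "networkx": "2.6.3",
    "matplotlib": "3.5.0",
    "tqdm": "4.64.0",
    "scipy": "1.7.3",
    "ogb": "1.3.6",
}


def version_tuple(version):
    """Convert '2.1.0+cu118' to (2, 1, 0).

    Keeps the leading digits of each dot-separated component and stops at the
    first component that has none.
    """
    parts = []
    for piece in version.split("+")[0].split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def check_environment(min_versions=MIN_VERSIONS):
    """Return {name: problem} for each package that is missing, broken, or too old."""
    problems = {}
    if sys.version_info < (3, 10):
        problems["python"] = f"found {sys.version.split()[0]}, need >= 3.10"
    for name, minimum in min_versions.items():
        if importlib.util.find_spec(name) is None:
            problems[name] = "not installed"
            continue
        try:
            module = importlib.import_module(name)
        except Exception as exc:  # e.g. torch_geometric installed without a usable torch
            problems[name] = f"import failed: {exc}"
            continue
        installed = getattr(module, "__version__", None)
        # Packages that do not expose __version__ are treated as OK rather than guessed at.
        if installed is not None and version_tuple(str(installed)) < version_tuple(minimum):
            problems[name] = f"found {installed}, need >= {minimum}"
    return problems


if __name__ == "__main__":
    issues = check_environment()
    for name, problem in issues.items():
        print(f"{name}: {problem}")
    print("All requirements satisfied." if not issues else f"{len(issues)} problem(s) found.")
```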

## Datasets

### 1. GOOD Datasets (Graph OOD Benchmark)
Includes: GOOD-MOTIF, GOOD-CMNIST, GOOD-SST2, GOOD-HIV
- **Location**: `data/good/` directory
- **Format**: Each dataset provides `train.pt`, `val.pt`, and `test.pt` split files
- **Source**: Part of the GOOD benchmark for graph OOD generalization
- **Download**: All datasets are also available from the [GOOD repository](https://github.com/divelab/GOOD/)

### 2. DisC Datasets (Disentangled Causal Datasets)
Includes: CMNIST, CFASHION, CKUZUSHIJI
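- **Location**: `data/disc/` directory
- **Format**: Each dataset provides `train.pkl`, `val.pkl`, and `test.pkl` split files
- **Feature**: Color is spuriously correlated with the class label (digits in CMNIST, clothing items in CFASHION, handwritten Japanese characters in CKUZUSHIJI), introducing environment-dependent bias
- **Download**: All datasets are also available from the [DisC repository](https://github.com/googlebaba/DisC)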
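
Both benchmarks expect three split files per dataset, and missing files lead to dynamic generation instead of loading. The standard-library check below reports which splits are present. It assumes a per-dataset subfolder layout (for example `data/good/good-motif/train.pt`), which this README does not spell out, so adjust `missing_splits` (a hypothetical helper) to the actual folder names.

```python
from pathlib import Path

# Split names and file extensions as described in the Datasets section.
SPLITS = ("train", "val", "test")
EXTENSIONS = {"good": ".pt", "disc": ".pkl"}


def missing_splits(data_root, benchmark, dataset):
    """List the split files that are absent for one dataset.

    Assumes a hypothetical layout data_root/<benchmark>/<dataset>/<split><ext>,
    e.g. data/good/good-motif/train.pt or data/disc/cmnist/val.pkl.
    """
    ext = EXTENSIONS[benchmark]
    folder = Path(data_root) / benchmark / dataset
    return [
        f"{split}{ext}"
        for split in SPLITS
        if not (folder / f"{split}{ext}").is_file()
    ]


if __name__ == "__main__":
    datasets = {
        "good": ["good-motif", "good-cmnist", "good-sst2", "good-hiv"],
        "disc": ["cmnist", "cfashion", "ckuzushiji"],
    }
    for benchmark, names in datasets.items():
        for name in names:
            missing = missing_splits("data", benchmark, name)
            status = "ok" if not missing else "missing " + ", ".join(missing)
            print(f"{benchmark}/{name}: {status}")
```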

### 3. Dynamic Generation
- If data files are not found, datasets will be dynamically generated with configurable bias levels
- Supports both image-based (DisC) and graph-based (GOOD) data formats
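
The generator used by this repository is not documented here. As a hedged illustration of what a bias level usually means for colored-image benchmarks, the sketch below follows the common Colored-MNIST recipe: with probability `bias`, a sample receives the color tied to its label, otherwise a uniformly random other color. The function name `assign_colors` and the `label % num_colors` color mapping are assumptions; the DisC `--color-bias` flag (default 0.9) presumably plays the role of `bias` for the training set.

```python
import random


def assign_colors(labels, num_colors, bias, seed=0):
    """Pick a color index per sample so that color matches the label with probability `bias`.

    With probability `bias` a sample gets its class's spurious color
    (label % num_colors); otherwise it gets one of the other colors uniformly.
    bias = 1 / num_colors makes color uninformative; bias near 1 makes it a strong shortcut.
    """
    rng = random.Random(seed)
    colors = []
    for y in labels:
        spurious = y % num_colors
        if rng.random() < bias:
            colors.append(spurious)
        else:
            others = [c for c in range(num_colors) if c != spurious]
            colors.append(rng.choice(others))
    return colors


if __name__ == "__main__":
    labels = [i % 10 for i in range(1000)]
    for bias in (0.9, 0.1):  # biased training split vs. color-agnostic split
        colors = assign_colors(labels, num_colors=10, bias=bias)
        agree = sum(c == y % 10 for c, y in zip(colors, labels)) / len(labels)
        print(f"bias={bias}: color matches label for {agree:.1%} of samples")
```

With `bias = 1 / num_colors` the color distribution no longer depends on the label, which is the usual way to build a color-agnostic evaluation split.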

## Training Scripts

### 1. GOOD Dataset Training
```bash
# Basic GOOD training
python experiments/train_good1.py --dataset good-motif
python experiments/train_good1.py --dataset good-cmnist
python experiments/train_good1.py --dataset good-sst2
python experiments/train_good1.py --dataset good-hiv

# With DIGL-specific features
python experiments/train_good1.py --dataset good-motif --use-wasserstein
python experiments/train_good1.py --dataset good-cmnist --use-mutual-info
python experiments/train_good1.py --dataset good-sst2 --num-envs 4

# Quick testing
python experiments/train_good1.py --dataset good-motif --quick-test
```

### 2. DisC Dataset Training
```bash
# Train on DisC datasets
python experiments/train_disc1.py --dataset cmnist
python experiments/train_disc1.py --dataset cfashion
python experiments/train_disc1.py --dataset ckuzushiji
python experiments/train_disc1.py --dataset all  # Train all three

# With advanced DIGL features
python experiments/train_disc1.py --dataset cmnist --use-wasserstein --use-mutual-info
python experiments/train_disc1.py --dataset cfashion --num-envs 4 --hidden-dim 256

# Quick test mode
python experiments/train_disc1.py --dataset cmnist --quick-test
```
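
To repeat a configuration over several seeds, a small driver can generate the commands. The script below is a hypothetical helper that is not shipped with the repository; it uses only flags documented under Command Line Arguments, prints the commands by default, and executes them only when invoked with `--run`. The `./results/sweep/...` folder naming is an arbitrary choice.

```python
import subprocess
import sys


def build_commands(script, datasets, seeds, extra_args=()):
    """Return one argv list per (dataset, seed) pair for a DIGL training script."""
    commands = []
    for dataset in datasets:
        for seed in seeds:
            commands.append([
                sys.executable, script,
                "--dataset", dataset,
                "--seed", str(seed),
                # One output folder per run so results do not overwrite each other
                # (the folder naming is a hypothetical choice).
                "--output-dir", f"./results/sweep/{dataset}/seed{seed}",
                *extra_args,
            ])
    return commands


if __name__ == "__main__":
    runs = build_commands(
        "experiments/train_disc1.py",
        ["cmnist", "cfashion", "ckuzushiji"],
        seeds=[0, 1, 2],
        extra_args=["--color-bias", "0.9"],
    )
    execute = "--run" in sys.argv[1:]  # dry run (print only) unless --run is given
    for argv in runs:
        print(" ".join(argv[1:]))
        if execute:
            subprocess.run(argv, check=True)
```

The same helper works for `experiments/train_good1.py` by passing GOOD dataset names instead.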

## Command Line Arguments

### Common Arguments for GOOD Training
```
--dataset           Dataset name [good-motif, good-cmnist, good-sst2, good-hiv]
--epochs            Number of training epochs (default: 200)
--batch-size        Batch size (default: 128)
--lr                Learning rate (default: 0.001)
--hidden-dim        Hidden dimension size (default: 64-256, dataset-dependent)
--num-envs          Number of augmented environments (default: 3)
--lambda-align      Weight for prototype alignment loss (default: 1.0)
--lambda-disentangle Weight for representation disentanglement loss (default: 0.5)
--lambda-class      Weight for classification loss (default: 1.0)
--lambda-adv        Weight for adversarial environment loss (default: 0.1)
--use-wasserstein   Use Wasserstein distance for prototype alignment
--use-mutual-info   Use mutual information for disentanglement
--output-dir        Output directory (default: ./results/digl_good)
--patience          Early stopping patience (default: 10-25, dataset-dependent)
--seed              Random seed (default: 42)
--device            Device to use [cuda, cpu] (default: cuda if available)
--quick-test        Quick test mode with reduced epochs
--alternating-steps Alternating optimization steps per epoch (default: 3-5, dataset-dependent)
--data-path         Path to GOOD dataset files (default: ./data/good)
```
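
The `--lambda-*` descriptions suggest that training combines four loss terms, each scaled by its own weight. Assuming a plain weighted sum (the scripts may handle the adversarial term differently, for example across the `--alternating-steps` updates), the objective can be sketched as follows. `LossWeights` and `total_loss` are hypothetical names, and the defaults mirror the table above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LossWeights:
    """Defaults mirror --lambda-class, --lambda-align, --lambda-disentangle, --lambda-adv."""
    lambda_class: float = 1.0
    lambda_align: float = 1.0
    lambda_disentangle: float = 0.5
    lambda_adv: float = 0.1


def total_loss(class_loss, align_loss, disentangle_loss, adv_loss, w=LossWeights()):
    """Weighted sum of the four loss terms (an assumed combination, not the repo's code).

    Only uses * and +, so it works with plain floats or scalar torch tensors.
    """
    return (w.lambda_class * class_loss
            + w.lambda_align * align_loss
            + w.lambda_disentangle * disentangle_loss
            + w.lambda_adv * adv_loss)
```

Keeping the weights in one immutable object makes it easy to log the exact configuration alongside each run.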

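Both training scripts stop early according to `--patience`. The class below is a generic sketch of patience-based early stopping on a higher-is-better validation metric (such as accuracy or ROC-AUC). It is not taken from the repository, which may monitor a different quantity or break ties differently; `EarlyStopping` and its `min_delta` parameter are hypothetical.

```python
class EarlyStopping:
    """Signal a stop once the metric has not improved for `patience` consecutive epochs.

    Assumes a higher-is-better validation metric.
    """

    def __init__(self, patience=20, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta  # minimum gain that counts as an improvement
        self.best = float("-inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def step(self, metric, epoch):
        """Record one epoch's metric; return True if training should stop."""
        if metric > self.best + self.min_delta:
            self.best = metric
            self.best_epoch = epoch
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

A typical loop calls `step` once per epoch after validation, breaks when it returns `True`, and then reloads the checkpoint saved at `best_epoch`.
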
### Common Arguments for DisC Training
```
--dataset           Dataset name [cmnist, cfashion, ckuzushiji, all]
--epochs            Number of training epochs (default: 200)
--batch-size        Batch size (default: 128)
--hidden-dim        Hidden dimension size (default: 128-256, dataset-dependent)
--lr                Learning rate (default: 0.001)
--num-envs          Number of augmented environments (default: 3)
--lambda-align      Weight for prototype alignment loss (default: 1.0)
--lambda-disentangle Weight for representation disentanglement loss (default: 0.5)
--lambda-class      Weight for classification loss (default: 1.0)
--lambda-adv        Weight for adversarial environment loss (default: 0.1)
--use-wasserstein   Use Wasserstein distance for prototype alignment
--use-mutual-info   Use mutual information for disentanglement
--output-dir        Output directory (default: ./results/digl_disc)
--patience          Early stopping patience (default: 15-20, dataset-dependent)
--seed              Random seed (default: 42)
--device            Device to use [cuda, cpu] (default: cuda if available)
--quick-test        Quick test mode with reduced epochs
--color-bias        Color bias strength for the training set (default: 0.9)
--img-size          Image size (default: 28)
```