# Training Diffusion Classifiers with Denoising-Assistance: Score-matching SDEs

This repository builds heavily on:
- https://github.com/yang-song/score_sde_pytorch
- https://github.com/chen-hao-chao/dlsm

This repository includes the following:
- PyTorch Lightning implementation of Score-matching SDEs
- Denoising-assistance training and sampling
- Semi-supervised training on CIFAR10, MNIST, and SVHN
- Positive-unlabeled (PU) training on MNIST and SVHN
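
The classifiers here operate on noisy inputs x_t produced by a score SDE, and the denoising-assisted variants additionally rely on a pretrained score model (hence `score_path` in the commands below). The NumPy sketch that follows illustrates the idea under stated assumptions: it uses the VE SDE (matching the `configs/ve/...` configs) with an illustrative geometric schedule σ(t) = σ_min (σ_max/σ_min)^t, denoises with Tweedie's formula E[x_0 | x_t] = x_t + σ² ∇ log p_σ(x_t), and, purely as an assumption about how denoising assistance feeds the classifier, stacks the noisy and denoised images along the channel axis. The helper names and σ values are hypothetical and are not read from this repository.

```python
import numpy as np

# Illustrative VE-SDE schedule; sigma_min/sigma_max are assumptions, not this repo's config.
def ve_sigma(t, sigma_min=0.01, sigma_max=50.0):
    """Geometric noise schedule sigma(t) = sigma_min * (sigma_max / sigma_min) ** t."""
    return sigma_min * (sigma_max / sigma_min) ** t


def perturb(x0, t, rng):
    """Sample x_t = x0 + sigma(t) * z with z ~ N(0, I), the VE-SDE perturbation kernel."""
    sigma = ve_sigma(t)
    return x0 + sigma * rng.standard_normal(x0.shape), sigma


def tweedie_denoise(x_t, sigma, score_fn):
    """Tweedie's formula: E[x0 | x_t] = x_t + sigma**2 * grad log p_sigma(x_t)."""
    return x_t + sigma ** 2 * score_fn(x_t, sigma)


def da_classifier_input(x_t, sigma, score_fn):
    """Hypothetical DA input: noisy and denoised images stacked on the channel axis (NCHW)."""
    x_hat = tweedie_denoise(x_t, sigma, score_fn)
    return np.concatenate([x_t, x_hat], axis=1)


# Toy usage: for N(0, I) data, p_sigma = N(0, (1 + sigma**2) I) has score -x / (1 + sigma**2).
rng = np.random.default_rng(0)
x0 = rng.standard_normal((4, 1, 28, 28))  # stand-in for a batch of MNIST-sized images
gaussian_score = lambda x, sigma: -x / (1.0 + sigma ** 2)
x_t, sigma = perturb(x0, t=0.5, rng=rng)
inputs = da_classifier_input(x_t, sigma, gaussian_score)
print(inputs.shape)  # (4, 2, 28, 28)
```

With the toy Gaussian score, the denoised channel equals x_t / (1 + σ²), the exact posterior mean for N(0, I) data; a trained score network would take the place of `gaussian_score`.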

Below are some sample commands that illustrate the available command-line options:
- Noisy CIFAR10 classifier with full data:
>python main.py --workdir workdirs/cifar_noisy --config=configs/ve/cifar10_ncsnpp_deep_continuous.py --config.training.denoise_augment=False --config.data.labels_per_class=-1 --config.training.score_model=False --config.training.clf_model=True --config.training.n_iters=150000
- Denoising-assisted CIFAR10 classifier with full data. Download the CIFAR10 NCSN++ checkpoint released by [Score-sde-pytorch](https://github.com/yang-song/score_sde_pytorch) and pass its path to `--config.training.score_path` in the command below:
>python main.py --workdir workdirs/cifar_da --config=configs/ve/cifar10_ncsnpp_deep_continuous.py --config.training.denoise_augment=True --config.data.labels_per_class=-1 --config.training.score_model=False --config.training.clf_model=True --config.training.n_iters=150000 --config.training.score_path="PATH/TO/CKPT"
- Denoising-assisted CIFAR10 classifier with 400 labels per class:
>python main.py --workdir workdirs/cifar_da_400 --config=configs/ve/cifar10_ncsnpp_deep_continuous.py --config.training.denoise_augment=True --config.data.labels_per_class=400 --config.training.score_model=False --config.training.clf_model=True --config.training.n_iters=150000 --config.training.score_path="PATH/TO/CKPT"
- Denoising-assisted MNIST classifier with 100 labels per class. If you have already trained the MNIST score model and placed it inside `scoredirs` (see the score-network training command below), the code automatically loads that checkpoint for training:
>python main.py --workdir workdirs/mnist_da_100 --config=configs/ve/mnist_ncsnpp_continuous.py --config.training.denoise_augment=True --config.data.labels_per_class=100 --config.training.score_model=False --config.training.clf_model=True --config.training.n_iters=150000
- Denoising-assisted MNIST classifier in the positive-unlabeled (PU) setting, with odd digits as the positive class, even digits as the negative class, and 1k labeled samples:
>python main.py --workdir workdirs/mnist_da_pu_100 --config=configs/ve/mnist_ncsnpp_continuous.py --config.training.denoise_augment=True --config.data.labels_per_class=1000 --config.training.score_model=False --config.training.clf_model=True --config.training.n_iters=150000 --config.data.pu=True --config.data.pu_config.positive_classes="(1, 3, 5, 7, 9)" --config.data.pu_config.use_classes="(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)"
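
The commands above control how much labeled data the classifier sees through `labels_per_class` (`-1` for full data) and, in the PU setting, which digits count as positive through `positive_classes`. A minimal NumPy sketch of such splits follows. `select_labeled` and `pu_split` are hypothetical helpers, the sketch follows the standard PU convention that only positives are ever labeled, and because the README does not say how `labels_per_class` maps onto the number of labeled positives in the PU case, the sketch takes that count as a separate argument.

```python
import numpy as np


def select_labeled(labels, labels_per_class, rng):
    """Hypothetical helper: indices of a class-balanced labeled subset.

    labels_per_class == -1 keeps every example labeled (the "full data" setting);
    the remaining examples would be treated as unlabeled in semi-supervised training.
    """
    labels = np.asarray(labels)
    if labels_per_class == -1:
        return np.arange(len(labels))
    picked = []
    for c in np.unique(labels):
        idx = np.flatnonzero(labels == c)
        picked.append(rng.choice(idx, size=min(labels_per_class, len(idx)), replace=False))
    return np.sort(np.concatenate(picked))


def pu_split(labels, positive_classes, n_labeled, rng):
    """Hypothetical PU split: binary targets (1 = positive class) plus a mask that reveals
    labels for only n_labeled randomly chosen positives; all other examples are unlabeled."""
    labels = np.asarray(labels)
    binary = np.isin(labels, positive_classes).astype(np.int64)
    pos_idx = np.flatnonzero(binary == 1)
    chosen = rng.choice(pos_idx, size=min(n_labeled, len(pos_idx)), replace=False)
    is_labeled = np.zeros(len(labels), dtype=bool)
    is_labeled[chosen] = True
    return binary, is_labeled


# Toy usage with 50 examples per digit.
rng = np.random.default_rng(0)
digits = np.repeat(np.arange(10), 50)
print(len(select_labeled(digits, -1, rng)), len(select_labeled(digits, 5, rng)))  # 500 50
binary, is_labeled = pu_split(digits, (1, 3, 5, 7, 9), n_labeled=100, rng=rng)
print(binary.sum(), is_labeled.sum())  # 250 100
```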

To train the MNIST score network, use:
>python main.py --mode train --workdir scoredirs/mnist_ve --config configs/ve/mnist_ncsnpp_continuous.py --config.training.score_model=True --config.training.clf_model=False 
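
Score networks in Score SDEs are trained with denoising score matching. For the VE SDE with the common λ(σ) = σ² weighting, the per-example objective is ‖σ s_θ(x_0 + σz, σ) + z‖² with z ~ N(0, I). The NumPy sketch below only illustrates that objective with a stand-in `score_fn` and an illustrative σ range; `dsm_loss` is a hypothetical name and this is not the repository's PyTorch Lightning training step.

```python
import numpy as np


def dsm_loss(score_fn, x0, sigmas, rng):
    """Denoising score matching for the VE SDE with lambda(sigma) = sigma**2 weighting:
    mean over the batch of || sigma * s(x0 + sigma * z, sigma) + z ||^2."""
    z = rng.standard_normal(x0.shape)
    s = sigmas.reshape(-1, *([1] * (x0.ndim - 1)))  # one sigma per example, broadcastable
    x_t = x0 + s * z
    residual = s * score_fn(x_t, s) + z
    return np.mean(np.sum(residual.reshape(len(x0), -1) ** 2, axis=1))


# Toy check: data concentrated at the origin, so p_sigma = N(0, sigma**2 I).
rng = np.random.default_rng(0)
x0 = np.zeros((2000, 4))
sigmas = 0.01 * (50.0 / 0.01) ** rng.uniform(1e-5, 1.0, size=len(x0))  # illustrative VE schedule
exact_score = lambda x, s: -x / s ** 2  # true score of N(0, s**2 I)
zero_score = lambda x, s: np.zeros_like(x)
print(dsm_loss(exact_score, x0, sigmas, rng))  # ~0 (up to rounding)
print(dsm_loss(zero_score, x0, sigmas, rng))   # close to 4, the data dimension
```

The exact score drives the residual to zero, whereas a zero score leaves an expected loss equal to the data dimension, which is why the loss is a useful signal for fitting s_θ.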

After training a classifier, you can generate 1024 samples per class with:
> python generate_samples.py --workdir cifar_da --num_samples 1024 --bs 256 --id 0

To run several sampling jobs in parallel, give each job a different `--id`; this keeps their `.npz` output files from overwriting one another. The generated samples are written to `<workdir>/samples`. All of the commands above support preemption and can resume after being interrupted.
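
The README does not spell out how a trained classifier steers class-conditional sampling. In Score SDEs this is typically done with classifier guidance: the classifier's input gradient ∇_x log p(y | x) is added to the unconditional score, since ∇_x log p(x | y) = ∇_x log p(x) + ∇_x log p(y | x). The sketch below assumes a toy linear softmax classifier with an analytic gradient and omits the noise-level argument that the repository's noise-conditional classifier would take; `log_softmax_grad`, `guided_score`, and `scale` are hypothetical names, and `generate_samples.py` may work differently.

```python
import numpy as np


def log_softmax_grad(x, W, b, y):
    """Gradient w.r.t. x of log p(y | x) for a linear softmax classifier, logits = x @ W.T + b."""
    logits = x @ W.T + b                          # (batch, num_classes)
    logits -= logits.max(axis=1, keepdims=True)   # stabilise the softmax
    p = np.exp(logits)
    p /= p.sum(axis=1, keepdims=True)
    return W[y] - p @ W                           # (batch, dim)


def guided_score(x, uncond_score, W, b, y, scale=1.0):
    """Classifier guidance: grad log p(x) + scale * grad log p(y | x).
    scale=1 is the exact Bayes-rule identity; other values tune guidance strength."""
    return uncond_score(x) + scale * log_softmax_grad(x, W, b, y)


# Toy usage: standard-normal prior score plus a random 3-class linear classifier.
rng = np.random.default_rng(0)
W, b = rng.standard_normal((3, 5)), rng.standard_normal(3)
x = rng.standard_normal((4, 5))
y = np.array([0, 1, 2, 1])  # requested class for each sample
print(guided_score(x, lambda v: -v, W, b, y).shape)  # (4, 5)
```

In a real sampler the guided score would replace the unconditional score inside each reverse-SDE or Langevin step, with the gradient obtained by autograd through the trained classifier.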

To evaluate the generated CIFAR10 samples, run:
> python run_lib_evaluation.py --config configs/eval/eval_cifar10_configs.py --mode full --workdir cifar_da --stat --latent --fidis --prdc
