# DisCoNet

DisCoNet is a generative model that combines Variational Autoencoders (VAEs) with adversarial training. A VAE is a deep generative model that learns an approximation of the training-data distribution, from which it can draw new samples. Adversarial training adds a discriminator network that learns to distinguish real samples from generated ones, while the generator is simultaneously trained to fool the discriminator. The trained discriminator can then be used for out-of-distribution (OOD) detection.
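
The two ingredients can be illustrated with a generic VAE-GAN-style objective. The pure-Python sketch below is written under our own assumptions (an MSE reconstruction term, a KL term toward a standard normal prior, and a binary cross-entropy adversarial term); it is not DisCoNet's actual loss, and the weights `w_recon` and `w_adv` are hypothetical rather than a description of what `--recon_weight` and `--gen_weight` control.

```python
import math
import random


def reparameterize(mu, logvar, rng):
    """Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, 1)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


def kl_to_standard_normal(mu, logvar):
    """KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over latent dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def mse(x, x_hat):
    """Mean squared reconstruction error between an input and its reconstruction."""
    return sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)


def bce(p, target, eps=1e-12):
    """Binary cross-entropy of a discriminator probability p against a 0/1 target."""
    return -(target * math.log(p + eps) + (1.0 - target) * math.log(1.0 - p + eps))


def vae_generator_loss(x, x_hat, mu, logvar, d_fake, w_recon=1.0, w_adv=1.0):
    """Hypothetical objective for the encoder/decoder (the generator).

    w_recon and w_adv are illustrative weights only (assumption).
    """
    recon = mse(x, x_hat)
    kl = kl_to_standard_normal(mu, logvar)
    adv = bce(d_fake, 1.0)  # the generator wants D(x_hat) to be classified as real
    return w_recon * recon + kl + w_adv * adv


def discriminator_loss(d_real, d_fake):
    """The discriminator wants D(x) -> 1 on real data and D(x_hat) -> 0 on generated data."""
    return bce(d_real, 1.0) + bce(d_fake, 0.0)


if __name__ == "__main__":
    rng = random.Random(0)
    mu, logvar = [0.1, -0.2], [0.0, -1.0]
    z = reparameterize(mu, logvar, rng)  # a latent sample the decoder would map to x_hat
    print(len(z), vae_generator_loss([0.5, 0.5], [0.4, 0.6], mu, logvar, d_fake=0.3))
    print(discriminator_loss(d_real=0.9, d_fake=0.3))
```

In a real training loop the two losses are minimized in alternation: one step updates the discriminator with `discriminator_loss`, the next updates the encoder and decoder with the generator loss.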

## Parameters

<center>

| Argument                      | Description                                        | Default  | Choices                          |
|-------------------------------|----------------------------------------------------|----------|----------------------------------|
| `--dataset`                   | Dataset name                                       | `mnist`  | `mnist`, `cifar10`, `tinyimagenet`, `imagenet` |
| `--batch_size`                | Batch size                                         | `128`    |                                  |
| `--n_epochs`                  | Number of epochs                                   | `100`    |                                  |
| `--lr`                        | Learning rate                                      | `0.0002` |                                  |
| `--latent_dim`                | Latent dimension                                   | `128`    |                                  |
| `--hidden_dims`               | Hidden dimensions (space-separated list)           | `None`   |                                  |
| `--checkpoint`                | Checkpoint path                                    | `None`   |                                  |
| `--num_samples`               | Number of samples                                  | `16`     |                                  |
| `--gen_weight`                | Generator weight                                   | `0.002`  |                                  |
| `--recon_weight`              | Reconstruction weight                              | `0.002`  |                                  |
| `--sample_and_save_frequency` | Sample and save frequency                          | `5`      |                                  |
| `--discriminator_checkpoint`  | Discriminator checkpoint path                      | `None`   |                                  |
| `--ood_task`                  | Type of OOD detection task                         | `near`   | `near`, `far`, `covar`           |

</center>

For a full description of each parameter, see [`util.py`](./../src/disconet/utils/util.py) or run the training script with `--help`:

    python train_disconet.py --help
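
The documented flags map naturally onto Python's `argparse`. The parser below is a hypothetical reconstruction written from the table above, not the code in `util.py`; in particular, the argument types and the `nargs="+"` handling of `--hidden_dims` are inferred from the example commands.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the documented flags; the real one is defined in util.py."""
    p = argparse.ArgumentParser(description="DisCoNet (illustrative argument parser)")
    # "imagenet" is included because the ImageNet-1K training command passes --dataset imagenet.
    p.add_argument("--dataset", default="mnist",
                   choices=["mnist", "cifar10", "tinyimagenet", "imagenet"])
    p.add_argument("--batch_size", type=int, default=128)
    p.add_argument("--n_epochs", type=int, default=100)
    p.add_argument("--lr", type=float, default=0.0002)
    p.add_argument("--latent_dim", type=int, default=128)
    # nargs="+" lets the flag take a space-separated list, e.g. --hidden_dims 64 128 256
    p.add_argument("--hidden_dims", type=int, nargs="+", default=None)
    p.add_argument("--checkpoint", type=str, default=None)
    p.add_argument("--num_samples", type=int, default=16)
    p.add_argument("--gen_weight", type=float, default=0.002)
    p.add_argument("--recon_weight", type=float, default=0.002)
    p.add_argument("--sample_and_save_frequency", type=int, default=5)
    p.add_argument("--discriminator_checkpoint", type=str, default=None)
    p.add_argument("--ood_task", default="near", choices=["near", "far", "covar"])
    return p


if __name__ == "__main__":
    args = build_parser().parse_args(
        "--dataset mnist --hidden_dims 64 128 256 --latent_dim 512".split())
    print(args.dataset, args.hidden_dims, args.latent_dim, args.batch_size)
    # prints: mnist [64, 128, 256] 512 128
```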

## Training

To replicate the experiments in the paper, use the following commands:

**MNIST**

    python train_disconet.py --dataset mnist --batch_size 512 --hidden_dims 64 128 256 --latent_dim 512 --n_epochs 200 --lr 1e-4 --gen_weight 1e-3 --recon_weight 1e-3

**CIFAR-10**

    python train_disconet.py --dataset cifar10 --batch_size 512 --hidden_dims 64 128 256 512 --latent_dim 1024 --n_epochs 250 --lr 5e-4 --gen_weight 1e-3 --recon_weight 1e-3

**TinyImageNet**

    python train_disconet.py --dataset tinyimagenet --batch_size 512 --hidden_dims 64 128 256 512 --latent_dim 1024 --n_epochs 140 --lr 5e-4 --gen_weight 1e-3 --recon_weight 1e-3

**ImageNet-1K**

    python train_disconet.py --dataset imagenet --batch_size 1256 --hidden_dims 64 128 256 512 1024 --latent_dim 1024 --n_epochs 140 --lr 1e-4 --gen_weight 2e-4 --recon_weight 2e-4
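
To launch these runs from one place, a small hypothetical helper (not part of the repository) could hold the hyperparameters listed above and build each command line. It assumes `train_disconet.py` is in the working directory and `python` is on the `PATH`; by default it only prints the commands.

```python
import shlex
import subprocess

# Hyperparameters copied from the training commands above (hypothetical helper, not in the repo).
# Learning rates and loss weights are strings so the commands match the README exactly.
CONFIGS = {
    "mnist": dict(batch_size=512, hidden_dims=[64, 128, 256], latent_dim=512,
                  n_epochs=200, lr="1e-4", gen_weight="1e-3", recon_weight="1e-3"),
    "cifar10": dict(batch_size=512, hidden_dims=[64, 128, 256, 512], latent_dim=1024,
                    n_epochs=250, lr="5e-4", gen_weight="1e-3", recon_weight="1e-3"),
    "tinyimagenet": dict(batch_size=512, hidden_dims=[64, 128, 256, 512], latent_dim=1024,
                         n_epochs=140, lr="5e-4", gen_weight="1e-3", recon_weight="1e-3"),
    "imagenet": dict(batch_size=1256, hidden_dims=[64, 128, 256, 512, 1024], latent_dim=1024,
                     n_epochs=140, lr="1e-4", gen_weight="2e-4", recon_weight="2e-4"),
}


def build_command(dataset):
    """Return the argv list for one replication run of train_disconet.py."""
    cmd = ["python", "train_disconet.py", "--dataset", dataset]
    for key, value in CONFIGS[dataset].items():
        cmd.append(f"--{key}")
        if isinstance(value, list):
            # list-valued flags such as --hidden_dims expand to several arguments
            cmd.extend(str(v) for v in value)
        else:
            cmd.append(str(value))
    return cmd


def run_all(datasets=("mnist", "cifar10", "tinyimagenet", "imagenet"), dry_run=True):
    """Print every command; execute them sequentially only when dry_run is False."""
    for name in datasets:
        cmd = build_command(name)
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # raises on the first failing run


if __name__ == "__main__":
    run_all()
```

Storing the numeric values as strings is deliberate: `str(1e-4)` yields `0.0001`, which would not match the commands above character for character.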

## OOD Detection

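To run out-of-distribution (OOD) detection, specify the in-distribution (ID) dataset with `--dataset`, the type of OOD task with `--ood_task`, and the path to the trained discriminator with `--discriminator_checkpoint`. The example below evaluates an MNIST model and reuses the architecture arguments (`--hidden_dims`, `--latent_dim`) from the MNIST training command: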

    python eval_disconet.py --ood_task far --dataset mnist --hidden_dims 64 128 256 --latent_dim 512 --discriminator_checkpoint ./../../models/DisCoNet/Discriminator_mnist.pt
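
This README does not state which metric `eval_disconet.py` reports. A common way to score a discriminator-based OOD detector, assumed here purely for illustration, is to treat the discriminator's output as an "in-distribution" score and compute the area under the ROC curve (AUROC) between ID and OOD test samples. The sketch below computes AUROC via the Mann-Whitney U statistic; `auroc` is a hypothetical helper, not part of the repository.

```python
def auroc(id_scores, ood_scores):
    """AUROC as P(score_id > score_ood), counting ties as 1/2 (Mann-Whitney U).

    Assumes higher scores mean "more in-distribution".
    """
    if not id_scores or not ood_scores:
        raise ValueError("need at least one ID and one OOD score")
    wins = 0.0
    for s_id in id_scores:
        for s_ood in ood_scores:
            if s_id > s_ood:
                wins += 1.0
            elif s_id == s_ood:
                wins += 0.5
    return wins / (len(id_scores) * len(ood_scores))


if __name__ == "__main__":
    # Toy discriminator outputs: ID samples mostly score higher than OOD samples.
    print(auroc([0.9, 0.4], [0.5, 0.1]))
```

An AUROC of 1.0 means every ID sample outscores every OOD sample, and 0.5 is chance level. The pairwise loop is O(n·m); for large test sets a sort-based implementation such as scikit-learn's `roc_auc_score` is preferable.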