# CoGAN: Collaborative GAN

## Introduction
This repository contains the code for the paper "CoGANs: Collaborative Generative Adversarial Networks".
In this type of GAN, multiple generators are trained to collaborate so that, together, they generate multi-modal datasets efficiently and robustly.

## Requirements and usage instructions
Before running any code, make sure you have set up your workspace using the `requirements.txt` supplied in the zip file (for example, with `pip install -r requirements.txt`).

The zip file contains five Python files:
1. `hyperparameters.py` - This file contains all the hyperparameters required for the training and evaluation of the network.
2. `utils.py` - Contains the helper functions and classes, such as custom callbacks, metrics and dataloaders.
3. `model.py` - Contains all model-related definitions, such as the architecture and the `train_step` to run.
4. `train.py` - This is the main script for training the model. It defines the loss functions for each individual model and calls the `model.fit()` method to execute the `train_step`. It takes an optional argument, `-c` or `--from_checkpoint`, which loads the pre-trained models stored in `./models/` before training.
5. `evaluate.py` - This is the main script for evaluating metrics such as FID (Fréchet Inception Distance), pairwise TVD (total variation distance) and average class probabilities. It automatically loads the models stored in the `./models/` directory, evaluates each generator both individually and jointly, and prints all metrics to the console.
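The README does not spell out how the pairwise TVD or the average class probabilities are computed. The following is a minimal sketch, assuming each generator is summarized by the average class-probability vector that some classifier assigns to its samples, and that TVD between generators i and j is `0.5 * sum_c |p_i[c] - p_j[c]|`. All function names below are hypothetical and are not taken from `evaluate.py`.

```python
from itertools import combinations


def average_class_probabilities(probs):
    """Mean of per-sample class-probability vectors (e.g. classifier softmax outputs).

    Assumption: `probs` is a non-empty list of equal-length lists, one per generated sample.
    """
    n = len(probs)
    num_classes = len(probs[0])
    return [sum(p[c] for p in probs) / n for c in range(num_classes)]


def total_variation_distance(p, q):
    """TVD between two discrete distributions: 0.5 * sum_c |p_c - q_c|, in [0, 1]."""
    if len(p) != len(q):
        raise ValueError("distributions must have the same number of classes")
    return 0.5 * sum(abs(a - b) for a, b in zip(p, q))


def pairwise_tvd(class_dists):
    """TVD for every unordered pair of generators, keyed by (i, j) with i < j.

    Assumption: class_dists[i] is the average class distribution of generator i.
    """
    return {
        (i, j): total_variation_distance(class_dists[i], class_dists[j])
        for i, j in combinations(range(len(class_dists)), 2)
    }


# Toy example with two classes and two samples per generator.
gen0 = average_class_probabilities([[1.0, 0.0], [0.5, 0.5]])
gen1 = average_class_probabilities([[0.0, 1.0], [0.5, 0.5]])
print(pairwise_tvd([gen0, gen1]))
```

A TVD near 0 means two generators produce similar class mixtures, while a value near 1 means they cover largely disjoint classes, which is the kind of specialization one might look for across collaborating generators.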

The directory structure is as follows:

- `generated_images` - This directory contains the generated images from both training and evaluation. The subfolder `gen` contains the images generated during evaluation, while `gen(i)` contains the images generated by the i-th generator during training at `save_freq` intervals.
- `logs` - Contains the log files written by the `tensorboard_callback` during training.
- `models` - Contains all the pre-trained models.

To train the model from scratch, run:

`python train.py`

To use pre-trained models and continue training:

`python train.py -c` OR `python train.py --from_checkpoint`
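The README does not show how `train.py` parses this flag. As a sketch under the assumption that it is a simple on/off switch, it could be declared with the standard-library `argparse` module as below; `parse_args` is a hypothetical helper, not the repository's actual code.

```python
import argparse


def parse_args(argv=None):
    """Parse train.py-style arguments (hypothetical helper, not the repository's code)."""
    parser = argparse.ArgumentParser(description="Train the CoGAN models.")
    # Boolean switch: absent -> False (train from scratch), present -> True (resume).
    parser.add_argument(
        "-c",
        "--from_checkpoint",
        action="store_true",
        help="load the pre-trained models stored in ./models/ before training",
    )
    # argv=None makes argparse read sys.argv[1:]; passing a list parses it explicitly.
    return parser.parse_args(argv)


args = parse_args(["-c"])
print(args.from_checkpoint)
```

Because `argparse` derives the attribute name from the first long option, both `-c` and `--from_checkpoint` set the same `args.from_checkpoint` value.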

To evaluate the metrics (this uses only the pre-trained models in `./models/`):

`python evaluate.py`

## Pre-trained models

The pre-trained models for each experiment are available at the following link:

- [Pre-trained models](https://drive.google.com/drive/folders/1lQjAowsZ9pyw_8w3SkbRz7UUzc3_n59x?usp=sharing) - trained using the default hyperparameters mentioned in the paper.
