# Theoretical generalization bounds for improving the efficiency of deep online training

## Environment setup (Python3)
1. [Install PyTorch](https://pytorch.org/get-started/locally/)
2. Install the required packages:
```
pip install -r requirements.txt
```

## Data
The CIFAR10 and CIFAR100 datasets are downloaded automatically by the code. You can also download them manually into the `./data/` directory from https://www.cs.toronto.edu/~kriz/cifar.html.

## Usage
1. Micro-batch online learning experiments:
```
python main.py --n_epochs 50 --lr 0.0001 --batch_size 256 --loss_types mse --seeds 3 4 5 --noise_rates 0 0.2 0.4 0.6 --T 2 3 4 --dataset CIFAR10 --n_workers_per_dl 4
```
- **Args:**
    - `--loss_types`: List of loss functions used for training.
    - `--noise_rates`: List of symmetric noise rates.
    - `--seeds`: List of random seeds for reproducibility.
    - `--T`: List of exponents; for each value T, the remaining half of the training set is split into 2^T microbatches.
    - `--batch_size`: Training batch size per routine.
    - `--n_epochs`: Training epochs per routine.
    - `--lr`: Learning rate.
    - `--n_workers_per_dl`: Number of workers used for each dataloader; defaults to half of the currently free CPUs.
    - `--dataset`: Dataset to run the experiments on (e.g. `CIFAR10`).
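
For readers unfamiliar with the term, the following is a minimal sketch of what a symmetric noise rate usually means: each label is replaced, with probability equal to the rate, by a class drawn uniformly at random. The function name `add_symmetric_noise` is hypothetical, and the convention of drawing only from the *other* classes is an assumption; the code in this repository may implement it differently.

```python
import random


def add_symmetric_noise(labels, noise_rate, n_classes, seed=0):
    """Return a copy of `labels` in which each label is flipped with probability `noise_rate`.

    Hypothetical helper for illustration only; not taken from this repository.
    """
    rng = random.Random(seed)
    noisy = []
    for y in labels:
        if rng.random() < noise_rate:
            # Assumed convention: draw uniformly from the other n_classes - 1 classes.
            other = rng.randrange(n_classes - 1)
            noisy.append(other if other < y else other + 1)
        else:
            noisy.append(y)
    return noisy


# Toy labels for a 10-class problem such as CIFAR10.
labels = [i % 10 for i in range(1000)]
noisy = add_symmetric_noise(labels, noise_rate=0.4, n_classes=10, seed=3)
flipped = sum(a != b for a, b in zip(labels, noisy))
print(f"{flipped / len(labels):.2f} of the labels were flipped")
```

With a rate of 0 the labels are returned unchanged, and under this convention a rate of 1 changes every label.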

2. Plotting results:
```
python plot.py --loss_types mse --noise_rates 0 0.2 0.4 0.6 --T 2 3 4 --dataset CIFAR10
```
- **Args:**
    - `--loss_types`: List of loss functions used for training.
    - `--noise_rates`: List of symmetric noise rates.
    - `--T`: Split the remaining half of the training set into 2^T microbatches.
    - `--dataset`: Dataset whose results are plotted (e.g. `CIFAR10`).
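
To make the role of `--T` concrete, here is a rough sketch of splitting the second half of a dataset's indices into 2^T microbatches. The function `split_into_microbatches` is hypothetical, and the choice of contiguous, unshuffled, near-equal chunks is an assumption made for illustration; the repository may shuffle or partition differently.

```python
def split_into_microbatches(n_samples, T):
    """Split the second half of range(n_samples) into 2**T contiguous index chunks.

    Hypothetical helper for illustration only; not taken from this repository.
    """
    half = n_samples // 2
    first_half = list(range(half))
    rest = list(range(half, n_samples))

    n_chunks = 2 ** T
    # Spread any remainder over the first `extra` chunks so sizes differ by at most 1.
    base, extra = divmod(len(rest), n_chunks)
    chunks, start = [], 0
    for i in range(n_chunks):
        size = base + (1 if i < extra else 0)
        chunks.append(rest[start:start + size])
        start += size
    return first_half, chunks


# The CIFAR10 training set has 50,000 images.
first_half, chunks = split_into_microbatches(50_000, T=3)
print(len(first_half), len(chunks), len(chunks[0]))  # 25000 8 3125
```

Under these assumptions, `--T 2 3 4` corresponds to 4, 8, and 16 microbatches respectively.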
