# Time-sensitive Weight Averaging for Practical Temporal Domain Generalization

This is the authors' official PyTorch implementation of the Time-sensitive Weight Averaging (TWA) method from the **NeurIPS 2023** paper **Time-sensitive Weight Averaging for Practical Temporal Domain Generalization**.

This repo builds on the authors' official PyTorch implementation of the Directional Data Augmentation (DDA) method from the **AAAI 2023** paper [Foresee What You Will Learn: Data Augmentation for Domain Generalization in Non-Stationary Environments](https://arxiv.org/pdf/2301.07845.pdf).

## Implementation Details and Differences

- The repo supports only the **Rotated MNIST** dataset.
- By default, **TWA** is applied together with **DDA** (**DDA + TWA**).
- Because **DDA** selects the snapshot with the best validation accuracy, **TWA** likewise selects the snapshots with the top validation accuracies. This differs slightly from the random sampling described in the paper.
- When the domain intervals become larger (e.g. 20 or 30 degrees for Rotated MNIST), **DDA** becomes more unstable and produces more **Collapsed Results**, where the model's outputs are completely random.
- **This instability also exists in the original DDA implementation and is not introduced by TWA.**
- To reproduce the results reported in the [**DDA**](https://arxiv.org/pdf/2301.07845.pdf) paper, we have to ignore the Collapsed Results. We evaluate the performance of **DDA + TWA** under the same settings.
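
The snapshot selection described above can be illustrated with a minimal sketch. It is not the code used in this repo: the function name `average_top_k`, the `(val_acc, state_dict)` layout, and the optional `coeffs` argument are all hypothetical, and uniform averaging is assumed when no coefficients are given. Plain floats stand in for parameter tensors so the example runs without PyTorch; the same arithmetic should apply to floating-point tensors.

```python
def average_top_k(snapshots, k, coeffs=None):
    """Average the weights of the k snapshots with the highest validation accuracy.

    snapshots: list of (val_acc, state_dict) pairs (hypothetical layout).
               Each state_dict maps a parameter name to a value that supports
               + and * (floats here, floating-point tensors in practice).
    coeffs:    optional list of k averaging coefficients, ordered from the
               best snapshot to the k-th best; uniform (1/k) if omitted.
    """
    if not 1 <= k <= len(snapshots):
        raise ValueError("k must be between 1 and the number of snapshots")

    # Keep the k snapshots with the top validation accuracies.
    top = sorted(snapshots, key=lambda s: s[0], reverse=True)[:k]

    if coeffs is None:
        coeffs = [1.0 / k] * k  # assumption: uniform averaging
    if len(coeffs) != k:
        raise ValueError("coeffs must contain exactly k values")

    # Weighted sum of each parameter across the selected snapshots.
    averaged = {}
    for name in top[0][1]:
        averaged[name] = sum(c * sd[name] for c, (_, sd) in zip(coeffs, top))
    return averaged


# Toy snapshots with made-up accuracies and one-parameter "weights".
snapshots = [
    (0.90, {"w": 1.0, "b": 0.0}),
    (0.95, {"w": 3.0, "b": 2.0}),
    (0.60, {"w": 100.0, "b": -50.0}),  # low accuracy, excluded when k=2
]
avg = average_top_k(snapshots, k=2)
```

Passing explicit `coeffs` allows non-uniform averaging; how the actual coefficients are obtained in TWA is described in the paper and is not reproduced here.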


## Prerequisites
- PyTorch >= 1.12.1 (with suitable CUDA and cuDNN versions)
- torchvision >= 0.10.0
- torchmeta >= 1.8.0
- Python3
- NumPy
- pandas

## Dataset
Rotated Gaussian and Rotated MNIST: [https://drive.google.com/file/d/1o80mLQcMHej9d-MznWjGp48QRBCyWTX9/view?usp=sharing](https://drive.google.com/file/d/1o80mLQcMHej9d-MznWjGp48QRBCyWTX9/view?usp=sharing)


## Training of DDA

Rotated MNIST experiment
```
python scripts/train.py --data_dir=../dataset --gpu 0 --algorithm DDA --dataset EDGRotatedMNIST --test_env 8 --steps 5001 --hparams "{\"env_number\":9}"
```

## Training of DDA + TWA

Rotated MNIST experiment
```
python scripts/train.py --data_dir=../dataset --gpu 0 --algorithm DDA --dataset EDGRotatedMNIST --seed 100 --test_env 8 --steps 5001 --n_save 6 --hparams "{\"env_number\":9}" --dom_dist 15 --twa --arch_steps 500 --batch_size 64
```
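
When aggregating test accuracies over several seeds, the Collapsed Results mentioned under Implementation Details need to be left out. The sketch below shows one possible way to do that. It assumes a collapsed run shows up as roughly chance-level accuracy (0.1 for the 10 MNIST classes); the function name, the `margin` threshold, and the accuracy values are hypothetical and are not results from this repo.

```python
from statistics import mean


def mean_without_collapsed(accuracies, chance=0.1, margin=0.05):
    """Average test accuracies over runs, skipping collapsed ones.

    A run is treated as collapsed when its accuracy is not clearly above
    chance level (assumption: accuracy <= chance + margin).
    Returns (mean accuracy of the kept runs, number of collapsed runs).
    """
    kept = [a for a in accuracies if a > chance + margin]
    if not kept:
        raise ValueError("all runs collapsed; nothing to average")
    return mean(kept), len(accuracies) - len(kept)


# Made-up per-seed accuracies for illustration only.
accs = [0.92, 0.10, 0.94, 0.11, 0.90]
avg_acc, n_collapsed = mean_without_collapsed(accs)
```

Reporting `n_collapsed` alongside the mean keeps it visible how many runs were excluded.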

## Acknowledgement
This code also builds on the [DomainBed](https://github.com/facebookresearch/DomainBed) codebase.
