
Code for "Weight Diffusion for Future: Learn to Generalize in Non-Stationary Environments"


<!-- TOC -->

- [Environments Preparation](#environments-preparation)
- [Datasets](#datasets)
- [Training and Testing](#training-and-testing)

<!-- /TOC -->


## Environments Preparation

* The code was developed with `Python 3.7.16` and `CUDA 12.2`. We recommend setting up a virtual environment before trying out the project.

    ```bash
    # Step-by-step installation
    conda create --name wdiff python=3.7.16
    conda activate wdiff

    # this installs the right pip and dependencies for the fresh python
    conda install -y ipython pip
  
    # install torch, torchvision and torchaudio
    pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html

    # this installs required packages
    pip install -r requirements.txt
    ```
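
As a quick sanity check after installation, the following Python sketch reports whether the pinned packages can be imported and whether their versions match the pins in the commands above. The helper name `check_environment` and the `EXPECTED` table are illustrative and not part of the repository.

```python
import importlib
import sys

# Versions pinned in the install commands above (local "+cu111" tags stripped).
EXPECTED = {"torch": "1.9.1", "torchvision": "0.10.1", "torchaudio": "0.9.1"}


def check_environment(expected=EXPECTED):
    """Return {package: (installed_version_or_None, matches_expected)}."""
    report = {}
    for name, wanted in expected.items():
        try:
            module = importlib.import_module(name)
        except ImportError:
            # Package is not installed in the active environment.
            report[name] = (None, False)
            continue
        installed = getattr(module, "__version__", "")
        # Compare only the public version, ignoring a local tag such as "+cu111".
        report[name] = (installed, installed.split("+")[0] == wanted)
    return report


if __name__ == "__main__":
    print("python", sys.version.split()[0])
    for name, (installed, ok) in check_environment().items():
        status = "ok" if ok else "mismatch"
        print(f"{name}: {installed or 'not installed'} ({status})")
```

A mismatch does not necessarily mean the code will fail; it only indicates that the environment differs from the one described here.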


## Datasets

- Download [yearbook.pkl](https://drive.google.com/u/0/uc?id=1mPpxoX2y2oijOvW1ymiHEYd7oMu2vVRb&export=download)
- Download [fmow.pkl](https://drive.google.com/u/0/uc?id=1s_xtf2M5EC7vIFhNv_OulxZkNvrVwIm3&export=download) and [fmow_v1.1.tar.gz](https://worksheets.codalab.org/bundles/0xaec91eb7c9d548ebb15e1b5e60f966ab)
- Download [huffpost.pkl](https://drive.google.com/u/0/uc?id=1jKqbfPx69EPK_fjgU9RLuExToUg7rwIY&export=download)
- Download [arxiv.pkl](https://drive.google.com/u/0/uc?id=1H5xzHHgXl8GOMonkb6ojye-Y2yIp436V&export=download)
- ONP and Moons will be provided in the `datasets` folder in the future.
- rmnist is downloaded automatically the first time the code runs.

The data folder should be structured as follows:
    
```
├── datasets/
│   ├── yearbook/
│   │   ├── yearbook.pkl
│   ├── rmnist/
│   │   ├── MNIST/
│   │   ├── rmnist.pkl
│   ├── ONP/
│   │   ├── processed/
│   ├── Moons/
│   │   ├── processed/
│   ├── huffpost/
│   │   ├── huffpost.pkl
│   ├── fMoW/
│   │   ├── fmow_v1.1/
│   │   │   ├── images/
│   │   ├── fmow.pkl
│   ├── arxiv/
│   │   ├── arxiv.pkl
```
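
Before launching a run, it may help to confirm that the files are where the layout above expects them. The sketch below is a hypothetical helper, not part of the repository. It assumes the `datasets/` folder sits in the current working directory, and it will report the `rmnist` entries as missing until the first run has downloaded them (likewise ONP and Moons until they are provided).

```python
from pathlib import Path

# Relative paths taken from the layout above; the "datasets" root is an assumption.
EXPECTED_PATHS = [
    "yearbook/yearbook.pkl",
    "rmnist/MNIST",
    "rmnist/rmnist.pkl",
    "ONP/processed",
    "Moons/processed",
    "huffpost/huffpost.pkl",
    "fMoW/fmow_v1.1/images",
    "fMoW/fmow.pkl",
    "arxiv/arxiv.pkl",
]


def missing_paths(root="datasets", expected=EXPECTED_PATHS):
    """Return the expected entries that do not exist under `root`."""
    root = Path(root)
    return [rel for rel in expected if not (root / rel).exists()]


if __name__ == "__main__":
    missing = missing_paths()
    if missing:
        print("Missing dataset entries:")
        for rel in missing:
            print("  datasets/" + rel)
    else:
        print("All expected dataset entries found.")
```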

## Training and Testing
```bash
# Run on the rmnist dataset:
python3 main.py --cfg ./configs/cfg_rmnist.yaml \
    device 4 \
    DM.params.base_learning_rate 5e-4 \
    trainer.warm_up 0.2 \
    trainer.num_DM_loop 1 \
    trainer.tradeoff_inv 10.0 \
    trainer.lr 1e-3 \
    DM.params.unet_config.params.model_channels 64 \
    DM.params.unet_config.params.num_head_channels 32 \
    DM.params.unet_config.params.num_groups 32 \
    trainer.sample_num 32
```
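
The trailing arguments in the command above read as alternating dotted keys and values that override entries in the YAML config. How `main.py` actually parses them is not shown here; the sketch below only illustrates one common way such pairs can be merged into a nested dictionary. The function name `apply_overrides` is hypothetical.

```python
import ast


def apply_overrides(config, overrides):
    """Apply a flat [key1, value1, key2, value2, ...] list to a nested dict in place."""
    if len(overrides) % 2 != 0:
        raise ValueError("overrides must come in key/value pairs")
    for dotted_key, raw in zip(overrides[0::2], overrides[1::2]):
        *parents, leaf = dotted_key.split(".")
        node = config
        for part in parents:
            # Walk down the nested dict, creating levels that do not exist yet.
            node = node.setdefault(part, {})
        try:
            # Interpret numbers such as "4" or "5e-4"; fall back to the raw string.
            value = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            value = raw
        node[leaf] = value
    return config


if __name__ == "__main__":
    cfg = {"trainer": {"lr": 1e-4}}
    apply_overrides(cfg, ["device", "4", "trainer.lr", "1e-3",
                          "DM.params.base_learning_rate", "5e-4"])
    print(cfg)
```

Under this reading, `DM.params.unet_config.params.model_channels 64` would set the `model_channels` field nested under `DM`, `params`, `unet_config`, `params`.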
