# Pattern-Guided Diffusion Models
This repository contains the code for Pattern-Guided Diffusion Models (PGDM).

## Environment
We recommend Conda. Follow these steps to set up the environment:
1. `conda create --name pgdm python=3.12.2 pip`
2. `conda activate pgdm`
3. `pip install torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 einops==0.8.1 tensorboardX==2.1 "protobuf==3.20.*" matplotlib tqdm tensorboard pandas numpy==1.26.4 opencv-python archetypes torchmetrics "transformers[torch]"`
4. `pip install git+https://github.com/aleixalcacer/archetypes.git@3fb66fdbe927b090a4a026751471615111942319`
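
To confirm that the pinned versions actually resolved (for example, that a later package did not pull in a different `numpy`), a small check like the one below can be run inside the activated environment. It is a hypothetical helper, not part of the repository. The pins are copied from the `pip install` command above, `protobuf` is treated as a prefix match because it is pinned as `3.20.*`, and local build tags such as `+cu121` are ignored.

```python
from importlib import metadata

# Pins copied from the pip command above. A trailing "." marks a wildcard pin.
PINS = {
    "torch": "2.2.1",
    "torchvision": "0.17.1",
    "torchaudio": "2.2.1",
    "einops": "0.8.1",
    "tensorboardX": "2.1",
    "protobuf": "3.20.",
    "numpy": "1.26.4",
}


def check_pins(pins=PINS, get_version=metadata.version):
    """Return {package: problem} for every pin that is missing or mismatched."""
    problems = {}
    for name, wanted in pins.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            problems[name] = "not installed"
            continue
        # Drop local version labels such as "+cu121" before comparing.
        base = found.split("+", 1)[0]
        ok = base.startswith(wanted) if wanted.endswith(".") else base == wanted
        if not ok:
            problems[name] = f"found {found}, expected {wanted}"
    return problems


if __name__ == "__main__":
    issues = check_pins()
    for name, problem in issues.items():
        print(f"{name}: {problem}")
    print("environment OK" if not issues else f"{len(issues)} pin(s) off")
```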

## Datasets
The preprocessed UWHVF data is located at `src/datasets/uwhvf_seqp_3`. The raw data can be downloaded from [here](https://github.com/uw-biomedical-ml/uwhvf/blob/0c07384b1345aca702f503a959d8815ff0bfa17a/alldata.json).
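
The link points to a GitHub file page rather than the file itself. If you want the raw JSON, a minimal sketch for fetching it from Python is shown below, assuming GitHub's usual mapping of `blob` URLs to `raw.githubusercontent.com`. The helper names and the destination path `raw_data/uwhvf/alldata.json` are hypothetical choices, and the sketch only prints top-level keys because the JSON schema is not described in this README.

```python
import json
import urllib.request
from pathlib import Path

# Blob URL from the link above (pinned to a specific commit).
BLOB_URL = ("https://github.com/uw-biomedical-ml/uwhvf/blob/"
            "0c07384b1345aca702f503a959d8815ff0bfa17a/alldata.json")


def blob_to_raw(url):
    """Convert github.com/<owner>/<repo>/blob/<ref>/<path> to its raw-content URL."""
    prefix = "https://github.com/"
    if not url.startswith(prefix) or "/blob/" not in url:
        raise ValueError(f"not a GitHub blob URL: {url}")
    owner_repo, ref_path = url[len(prefix):].split("/blob/", 1)
    return f"https://raw.githubusercontent.com/{owner_repo}/{ref_path}"


def download_uwhvf(dest="raw_data/uwhvf/alldata.json", url=BLOB_URL):
    """Download the raw UWHVF JSON to a hypothetical local path and parse it."""
    dest = Path(dest)
    dest.parent.mkdir(parents=True, exist_ok=True)
    urllib.request.urlretrieve(blob_to_raw(url), dest)
    with dest.open() as f:
        data = json.load(f)
    # Show only the top-level keys; the schema is not documented here.
    if isinstance(data, dict):
        print(sorted(data)[:10])
    return data
```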

Download the raw 3D keypoints and motion data for the AIST++ dataset from [here](https://google.github.io/aistplusplus_dataset/index.html). Unzip the data into the `src/raw_data` folder; this should produce two directories, `src/raw_data/keypoints3d` and `src/raw_data/motions`. Then preprocess the data by running `src/preprocess_aistpp.sh PATH_TO_RAW_DATA`.
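
Before preprocessing, it can help to confirm that the unzipped layout matches the two directories described above. The sketch below uses hypothetical helper names: it checks for `keypoints3d` and `motions` under the raw-data folder and then calls the preprocessing script through `bash`. It assumes commands are run from the repository root and that `PATH_TO_RAW_DATA` is the folder containing both subdirectories; adjust if the script expects otherwise.

```python
import subprocess
from pathlib import Path

EXPECTED_SUBDIRS = ("keypoints3d", "motions")


def missing_aistpp_dirs(raw_root="src/raw_data"):
    """Return the expected AIST++ subdirectories that are absent under raw_root."""
    root = Path(raw_root)
    return [name for name in EXPECTED_SUBDIRS if not (root / name).is_dir()]


def preprocess_aistpp(raw_root="src/raw_data", script="src/preprocess_aistpp.sh"):
    """Check the unzipped layout, then run the preprocessing script on it."""
    missing = missing_aistpp_dirs(raw_root)
    if missing:
        raise FileNotFoundError(f"missing under {raw_root}: {', '.join(missing)}")
    # Assumption: PATH_TO_RAW_DATA is the folder holding keypoints3d/ and motions/.
    subprocess.run(["bash", script, str(raw_root)], check=True)
```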

The pre-extracted archetypal patterns are located at `src/*_aa_object.pkl`.
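
To inspect these files, they can be unpickled directly. The sketch below (hypothetical helper name) loads every file matching the glob and keys it by filename. It assumes the pickles hold objects from the `archetypes` package installed above, so that package must be importable when loading; the sketch only reports each object's type rather than guessing at its attributes.

```python
import glob
import pickle
from pathlib import Path


def load_archetype_objects(pattern="src/*_aa_object.pkl"):
    """Unpickle every pre-extracted archetype file, keyed by its filename stem."""
    objects = {}
    for path in sorted(glob.glob(pattern)):
        # Unpickling re-imports the classes the objects were built from,
        # so the environment must match the one described above.
        with open(path, "rb") as f:
            objects[Path(path).stem] = pickle.load(f)
    return objects


if __name__ == "__main__":
    for name, obj in load_archetype_objects().items():
        print(name, type(obj).__name__)
```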

## Training and Evaluation

1. Train the pattern prediction model.
    ```
    python3 train_pattern_prediction.py --dataset DSET --batch_size BATCH_SIZE --lr LR --epochs EPOCHS --patience PATIENCE --n_frames HISTORY_STEPS --n_horizon HORIZON_STEPS
    ```
2. Evaluate the pattern prediction model.
    ```
    python3 test_pattern_prediction.py --tr_dataset DSET --te_dataset DSET --model_pth PATTERN_MODEL_PTH --n_frames HISTORY_STEPS --n_horizon HORIZON_STEPS
    ```
3. Train PGDM. Set `REPRESENTATION` to `td` for UWHVF and to `kp_norm` for AIST++.
    ```
    python3 train_diffusion.py --dataset DSET --representation REPRESENTATION --batch_size BATCH_SIZE --lr LR --pattern_model_pth PATTERN_MODEL_PTH --n_frames HISTORY_STEPS --n_horizon HORIZON_STEPS
    ```
4. Evaluate PGDM.
    ```
    python3 test_diffusion.py --tr_dataset DSET --te_dataset DSET --representation REPRESENTATION --split test --pattern_model_pth PATTERN_MODEL_PTH --model_pth PGDM_PTH --guidance_type relu --guidance_scale W_BAR --error_scale GAMMA --guide_mix --n_frames HISTORY_STEPS --n_horizon HORIZON_STEPS
    ```
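
The four steps share most of their placeholders, so one way to keep them consistent is to generate all commands from a single config. The sketch below is hypothetical: the `RunConfig` fields and helper names are not part of the repository. It reproduces the flags listed above, passes `dset` as both the training and test dataset, and prints the commands unless `dry_run=False`. Checkpoint paths must point to wherever steps 1 and 3 actually save their models, which this README does not specify, and the numeric values in the usage section are placeholders, not tuned settings.

```python
import shlex
import subprocess
from dataclasses import dataclass


@dataclass
class RunConfig:
    # Hypothetical container for the README placeholders.
    dset: str
    representation: str      # td for UWHVF, kp_norm for AIST++
    batch_size: int
    lr: float
    epochs: int
    patience: int
    n_frames: int            # HISTORY_STEPS
    n_horizon: int           # HORIZON_STEPS
    pattern_model_pth: str   # PATTERN_MODEL_PTH, produced by step 1
    pgdm_pth: str            # PGDM_PTH, produced by step 3
    guidance_scale: float    # W_BAR
    error_scale: float       # GAMMA


def build_commands(c):
    """Return the four training/evaluation commands, in order, as argv lists."""
    horizon = ["--n_frames", str(c.n_frames), "--n_horizon", str(c.n_horizon)]
    return [
        ["python3", "train_pattern_prediction.py", "--dataset", c.dset,
         "--batch_size", str(c.batch_size), "--lr", str(c.lr),
         "--epochs", str(c.epochs), "--patience", str(c.patience), *horizon],
        ["python3", "test_pattern_prediction.py", "--tr_dataset", c.dset,
         "--te_dataset", c.dset, "--model_pth", c.pattern_model_pth, *horizon],
        ["python3", "train_diffusion.py", "--dataset", c.dset,
         "--representation", c.representation, "--batch_size", str(c.batch_size),
         "--lr", str(c.lr), "--pattern_model_pth", c.pattern_model_pth, *horizon],
        ["python3", "test_diffusion.py", "--tr_dataset", c.dset,
         "--te_dataset", c.dset, "--representation", c.representation,
         "--split", "test", "--pattern_model_pth", c.pattern_model_pth,
         "--model_pth", c.pgdm_pth, "--guidance_type", "relu",
         "--guidance_scale", str(c.guidance_scale),
         "--error_scale", str(c.error_scale), "--guide_mix", *horizon],
    ]


def run_pipeline(c, dry_run=True):
    """Print each command; execute it (stopping on failure) unless dry_run."""
    for argv in build_commands(c):
        print(shlex.join(argv))
        if not dry_run:
            subprocess.run(argv, check=True)


if __name__ == "__main__":
    # Placeholder values for illustration only.
    cfg = RunConfig(dset="uwhvf", representation="td", batch_size=32, lr=1e-4,
                    epochs=100, patience=10, n_frames=4, n_horizon=2,
                    pattern_model_pth="PATTERN_MODEL_PTH", pgdm_pth="PGDM_PTH",
                    guidance_scale=1.0, error_scale=1.0)
    run_pipeline(cfg)  # dry run: prints the commands only
```

Running the steps with `check=True` stops the pipeline at the first failing script, so a failed training run never feeds a missing checkpoint into the evaluation steps.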