
## Main code
The main code is in `fastvideo/SRPO.py`. The rest of the project builds on existing open-source software.
## Dependencies and Installation

```bash
conda create -n SRPO python=3.10.16 -y
conda activate SRPO
bash ./env_setup.sh 
```

## Training
### Prepare Training Model
1. Pretrained model: download the FLUX.1-dev checkpoints from [Hugging Face](https://huggingface.co/black-forest-labs/FLUX.1-dev) to `./data/flux`.
```bash
mkdir -p ./data/flux
huggingface-cli login
huggingface-cli download --resume-download black-forest-labs/FLUX.1-dev --local-dir ./data/flux
```
2. Reward model: download the HPS-v2.1 checkpoint (`HPS_v2.1_compressed.pt`) from [Hugging Face](https://huggingface.co/xswu/HPSv2/tree/main) and the CLIP ViT-H-14 checkpoint from `laion/CLIP-ViT-H-14-laion2B-s32B-b79K` to `./data/hps_ckpt`.
```bash
mkdir -p ./data/hps_ckpt
huggingface-cli login
huggingface-cli download --resume-download xswu/HPSv2 HPS_v2.1_compressed.pt --local-dir ./data/hps_ckpt
huggingface-cli download --resume-download laion/CLIP-ViT-H-14-laion2B-s32B-b79K open_clip_pytorch_model.bin --local-dir ./data/hps_ckpt
```
3. (Optional) Reward model: download the PickScore checkpoint from [Hugging Face](https://huggingface.co/yuvalkirstain/PickScore_v1) to `./data/ps` and the CLIP ViT-H-14 checkpoint to `./data/clip`.
```bash
mkdir -p ./data/ps
huggingface-cli login
python ./scripts/huggingface/download_hf.py --repo_id yuvalkirstain/PickScore_v1 --local_dir ./data/ps
python ./scripts/huggingface/download_hf.py --repo_id laion/CLIP-ViT-H-14-laion2B-s32B-b79K --local_dir ./data/clip
```
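
Before moving on, it may help to confirm that the required checkpoints landed where the commands above put them. The following is a minimal sketch, not part of the repository: `check_paths` is a hypothetical helper, and the file locations are assumed from the `--local-dir` arguments used above.

```bash
#!/bin/bash
# Sketch: report any expected checkpoint path that does not exist.
# check_paths is a hypothetical helper, not something the project ships.
check_paths() {
    local missing=0
    for p in "$@"; do
        if [ ! -e "$p" ]; then
            echo "missing: $p"
            missing=1
        fi
    done
    # Return 0 if every path exists, 1 otherwise.
    return "$missing"
}

# Paths assumed from the download commands in steps 1 and 2.
check_paths \
    ./data/flux \
    ./data/hps_ckpt/HPS_v2.1_compressed.pt \
    ./data/hps_ckpt/open_clip_pytorch_model.bin \
    || echo "Some checkpoints are missing; re-run the download commands above."
```

If you also use PickScore, you can pass `./data/ps` and `./data/clip` to the same helper.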

### Prepare Training Data

```bash
# Write your training prompts into ./prompts.txt. Online RL needs only text prompts, not image-text pairs.
vi ./prompts.txt
# Pre-extract text embeddings from your custom training dataset to speed up training.
bash scripts/preprocess/preprocess_flux_rl_embeddings.sh
cp videos2caption2.json  ./data/rl_embeddings
```
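
If you prefer to create the prompt file without an editor, a heredoc works as well. This is a sketch under one assumption: that the preprocessing script reads one prompt per line, which you should confirm against the script before relying on it. The two prompts and the `PROMPTS_FILE` variable are placeholders for illustration only.

```bash
#!/bin/bash
# Assumption: one prompt per line. Check the preprocessing script to confirm.
# PROMPTS_FILE is a hypothetical variable; it defaults to the path used above.
PROMPTS_FILE="${PROMPTS_FILE:-./prompts.txt}"

# Placeholder prompts; replace them with your own.
cat > "$PROMPTS_FILE" <<'EOF'
A photo of a red fox standing in fresh snow
A watercolor painting of a lighthouse at dusk
EOF

# Print the number of non-empty prompt lines as a quick check.
grep -c . "$PROMPTS_FILE"
```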
### Inference
Run inference on our example cases. First set `model_path` in `vis.py` to your checkpoint.
```bash
torchrun --nnodes=1 --nproc_per_node=8 \
    --node_rank 0 \
    --rdzv_endpoint $CHIEF_IP:29502 \
    --rdzv_id 456 \
    vis.py 
```
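
The command above expects `CHIEF_IP` to hold the address of the rendezvous host. If your environment does not set it and you run on a single machine, one option is to fall back to the loopback address before launching. This is a sketch under that assumption, not a setting the project documents.

```bash
#!/bin/bash
# Assumption: single-machine run where CHIEF_IP is not already provided.
# Keep any existing value; otherwise fall back to the loopback address.
export CHIEF_IP="${CHIEF_IP:-127.0.0.1}"
echo "rendezvous endpoint: ${CHIEF_IP}:29502"
```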
### Full-parameter Training

- Train with HPS-v2.1 as the reward model in the reinforcement learning process.
    ```bash 
    bash scripts/finetune/SRPO_training_hpsv2.sh
    ```
- (Optional) Train with PickScore as the reward model in the reinforcement learning process.
    ```bash
    bash scripts/finetune/SRPO_training_ps.sh
    ```
    > ⚠️ The current control words are designed for HPS-v2.1, so training with PickScore may give worse results than training with HPS-v2.1.

- Run distributed training with pdsh.
  ```bash
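    #!/bin/bash
    # Build the host list: split NODE_IP_LIST on commas, strip a trailing ":8", and drop lines matching 1.1.1.1.
    echo "$NODE_IP_LIST" | tr ',' '\n' | sed 's/:8$//' | grep -v '1.1.1.1' > /tmp/pssh.hosts
    # Join the hosts back into a comma-separated list for pdsh.
    node_ip=$(paste -sd, /tmp/pssh.hosts)
    # Launch the training script on every node.
    pdsh -w "$node_ip" "conda activate SRPO;cd <project path>; bash scripts/finetune/SRPO_training_hpsv2.sh"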
