#!/usr/bin/env bash
# Aircraft Mask Weighted Training Script
# Uses the Masks2 folder, with a mask-weighted loss.
#
# Usage (run from the directory that contains ./train/config/, since
# OMINI_CONFIG below is a relative path):
#   WANDB_API_KEY=<your key> bash <this script>
# If WANDB_API_KEY is not set, wandb falls back to credentials saved by
# `wandb login`.

set -euo pipefail

# *[Specify the GPU devices to use]
export CUDA_VISIBLE_DEVICES=1,2

# *[Set the path to the training config file]
export OMINI_CONFIG="./train/config/aircraft_mask_weighted.yaml"

# *[Set WandB API key and project name]
# Never hardcode the API key in this file: anything committed to the
# repository is readable by everyone with access to it. Supply it through the
# environment (or `wandb login`) instead.
if [[ -z "${WANDB_API_KEY:-}" ]]; then
  echo "Warning: WANDB_API_KEY is not set; relying on 'wandb login' credentials." >&2
fi
export WANDB_PROJECT='OminiControl_MaskWeighted'  # WandB project name
export WANDB_NAME='aircraft_mask_weighted'         # WandB run name

# Enable HuggingFace tokenizers parallelism explicitly.
export TOKENIZERS_PARALLELISM=true

# Fail fast with a clear message rather than letting the trainer crash later
# on a missing config (the path is resolved against the current directory).
if [[ ! -f "$OMINI_CONFIG" ]]; then
  echo "Error: config file not found: $OMINI_CONFIG" >&2
  echo "(resolved relative to the current directory: $PWD)" >&2
  exit 1
fi

# Print configuration
echo "=============================================="
echo "Aircraft Mask Weighted Training"
echo "=============================================="
echo "GPU: $CUDA_VISIBLE_DEVICES"
echo "Config: $OMINI_CONFIG"
echo "WandB Project: $WANDB_PROJECT"
echo "Using: Masks2 folder"
echo "Loss: Mask-weighted MSE"
echo "=============================================="

# *[Launch the training script]
# exec replaces this shell, so a signal sent to the script's PID (e.g. by a
# job scheduler) reaches accelerate directly, and accelerate's exit status is
# the script's exit status.
exec accelerate launch --main_process_port 41354 \
  -m omini.train_flux.train_aircraft_mask_weighted
