#!/bin/bash
#
# Aircraft mask-weighted training launcher (normalized loss).
#
# Exports the environment variables below and then starts the module
# omini.train_flux.train_aircraft_mask_weighted_normalized through
# `accelerate launch`.
#
# Per the original author's notes: training uses the "Masks2" folder with a
# normalized, mask-weighted loss:
#   loss = sum(diff^2 * weights) / sum(weights)
#
# Required environment:
#   WANDB_API_KEY  WandB API key. It is deliberately NOT stored in this file;
#                  export it in your shell before running, e.g.
#                    WANDB_API_KEY=... bash <this script>
#                  (presumably a prior `wandb login` also works when it is
#                  unset -- verify against your WandB setup).
#
# NOTE: never commit an API key to this script. Any key that was ever
# committed to version control should be treated as leaked and revoked.

# Abort on command failure, on use of an unset variable, and on a failure in
# any stage of a pipeline.
set -euo pipefail

# *[Specify the GPU devices to use]
export CUDA_VISIBLE_DEVICES=2

# *[Set the path to the training config file]
# Relative path: the script is expected to be run from the directory that
# contains ./train (assumption based on the path -- TODO confirm).
export OMINI_CONFIG="./train/config/aircraft_mask_weighted_normalized.yaml"

# *[Set WandB API key and project name]
# The key comes from the caller's environment rather than from source code.
# A missing key only produces a warning so that setups which authenticate
# another way are not broken by this check.
if [[ -n "${WANDB_API_KEY:-}" ]]; then
  export WANDB_API_KEY
else
  printf '%s\n' \
    "warning: WANDB_API_KEY is not set; WandB logging may fail to authenticate." >&2
fi
export WANDB_PROJECT='OminiControl_MaskWeighted'      # WandB project name
export WANDB_NAME='aircraft_mask_weighted_normalized' # WandB run name

# Print configuration.
# NOTE(review): "Masks2", the loss name and "Lambda Weight: 20.0" are literal
# strings here; they are not read from $OMINI_CONFIG, so keep them in sync
# with the config file by hand.
printf '%s\n' \
  "=============================================" \
  "Aircraft Mask Weighted Training (Normalized)" \
  "=============================================" \
  "GPU: $CUDA_VISIBLE_DEVICES" \
  "Config: $OMINI_CONFIG" \
  "WandB Project: $WANDB_PROJECT" \
  "Using: Masks2 folder" \
  "Loss: Normalized Mask-weighted MSE" \
  "Lambda Weight: 20.0" \
  "============================================="

# Presumably consumed by the Hugging Face tokenizers library -- TODO confirm.
export TOKENIZERS_PARALLELISM=true

# Fail early with a clear message if the launcher is not installed.
if ! command -v accelerate >/dev/null 2>&1; then
  printf '%s\n' "error: 'accelerate' was not found in PATH." >&2
  exit 1
fi

# *[Launch the training script]
# `exec` replaces this shell with the launcher, so signals (e.g. Ctrl-C) and
# the exit status go directly to/from the training process.
exec accelerate launch --main_process_port 41357 \
  -m omini.train_flux.train_aircraft_mask_weighted_normalized
