#!/bin/bash
# SiD-DiT [NAME] [NAME]
# Use Sana_600M_512px_diffusers by default
# Automatically selects least used GPUs

set -e  # Exit on any error

# Positional arguments (all optional):
#   $1 MODEL_SELECTION   model alias, matched case-insensitively (default: auto)
#   $2 WEIGHTING_SCHEME  loss weighting scheme (default: 1_minus_sigma)
#   $3 NUM_GPUS          number of GPUs to train on (default: 4)
#   $4 NOISE_TYPE        noise type (default: fresh)
#   $5 RUN_DIR           output directory prefix
MODEL_SELECTION="${1:-auto}"
WEIGHTING_SCHEME="${2:-1_minus_sigma}"
NUM_GPUS="${3:-4}"
NOISE_TYPE="${4:-fresh}"
RUN_DIR="${5:-ANON/data/image_experiment/sid_flow}"

# Map a model alias to its local checkpoint directory and HuggingFace Hub id.
# Arguments: $1 - model alias (any letter case)
# Globals:   MODEL_LOCAL, MODEL_HF (written)
# Returns:   0 on success, 1 for an unknown alias (message on stderr)
resolve_model_selection() {
    local alias_lc candidate
    alias_lc=$(printf '%s\n' "$1" | tr '[:upper:]' '[:lower:]')

    # Patterns are written in lower case because the alias is lower-cased above.
    case "$alias_lc" in
        "auto")
            # Pick the first Sana checkpoint that exists locally; if none is
            # present, fall back to Sana_600M_512px_diffusers from the Hub.
            MODEL_LOCAL="/data/datasets/Sana_600M_512px_diffusers"
            MODEL_HF="Efficient-Large-Model/Sana_600M_512px_diffusers"
            for candidate in \
                Sana_600M_512px_diffusers \
                Sana_Sprint_0.6B_1024px_teacher_diffusers \
                SANA_Sprint_1.6B_1024px_teacher_diffusers \
                Sana_1600M_1024px_BF16_diffusers \
                Sana_1600M_512px_diffusers; do
                if [ -d "/data/datasets/$candidate" ]; then
                    MODEL_LOCAL="/data/datasets/$candidate"
                    MODEL_HF="Efficient-Large-Model/$candidate"
                    break
                fi
            done
            ;;
        "-h"|"--help")
            # Help was requested: the usage handler near the end of the script
            # prints the text and exits, so no model is needed here.
            MODEL_LOCAL=""
            MODEL_HF=""
            ;;
        "600m"|"sana_600m")
            MODEL_LOCAL="/data/datasets/Sana_600M_512px_diffusers"
            MODEL_HF="Efficient-Large-Model/Sana_600M_512px_diffusers"
            ;;
        "sprint_0.6b"|"sprint_06b"|"0.6b"|"06b")
            MODEL_LOCAL="/data/datasets/Sana_Sprint_0.6B_1024px_teacher_diffusers"
            MODEL_HF="Efficient-Large-Model/Sana_Sprint_0.6B_1024px_teacher_diffusers"
            ;;
        "sprint_1.6b"|"sprint_16b"|"1.6b"|"16b")
            MODEL_LOCAL="/data/datasets/SANA_Sprint_1.6B_1024px_teacher_diffusers"
            MODEL_HF="Efficient-Large-Model/SANA_Sprint_1.6B_1024px_teacher_diffusers"
            ;;
        "1600m_1024"|"1.6b_1024")
            MODEL_LOCAL="/data/datasets/Sana_1600M_1024px_BF16_diffusers"
            MODEL_HF="Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers"
            ;;
        "1600m_512"|"1.6b_512")
            MODEL_LOCAL="/data/datasets/Sana_1600M_512px_diffusers"
            MODEL_HF="Efficient-Large-Model/Sana_1600M_512px_diffusers"
            ;;
        "sd3"|"stable-diffusion-3"|"stable-diffusion-3-medium"|"sd3-medium")
            MODEL_LOCAL=""  # No local path by default
            MODEL_HF="stabilityai/stable-diffusion-3-medium-diffusers"
            ;;
        "sd3.5-medium"|"stable-diffusion-3.5-medium"|"stable-diffusion-3.5-medium-diffusers")
            MODEL_LOCAL=""  # No local path by default
            MODEL_HF="stabilityai/stable-diffusion-3.5-medium"
            ;;
        "sd3.5-large"|"stable-diffusion-3.5-large"|"stable-diffusion-3.5-large-diffusers")
            MODEL_LOCAL=""  # No local path by default
            MODEL_HF="stabilityai/stable-diffusion-3.5-large"
            ;;
        "flux"|"flux.dev"|"flux.1-dev")
            MODEL_LOCAL=""  # No local path by default
            MODEL_HF="black-forest-labs/FLUX.1-dev"
            ;;
        *)
            # Plain printf on purpose: the colored print_* helpers are defined
            # further down the script and do not exist yet at this point.
            printf '[ERROR] Invalid model selection: %s\n' "$1" >&2
            printf '[INFO] Available options: %s\n' \
                "auto, 600m, sprint_0.6b, sprint_1.6b, 1600m_1024, 1600m_512, sd3, sd3.5-medium, sd3.5-large, flux" >&2
            return 1
            ;;
    esac
}

resolve_model_selection "$MODEL_SELECTION" || exit 1

# Resolve the checkpoint to load: start from the HuggingFace Hub id and switch
# to the local directory only when that directory is present on disk.
MODEL="$MODEL_HF"
if [ -d "$MODEL_LOCAL" ]; then
    MODEL="$MODEL_LOCAL"
fi

# Tag the run directory with the lower-cased model selection.
MODEL_NAME="$(echo "$MODEL_SELECTION" | tr '[:upper:]' '[:lower:]')"
RUN_DIR="${RUN_DIR}_${MODEL_NAME}"

# Model-specific training profile: image resolution and per-GPU batch size.
# BATCH_GPU is tuned for 4x A6000-Ada 48GB GPUs; increase it for 80GB GPUs or
# when more GPUs are available.
# Example: for 8x 80GB GPUs, try BATCH_GPU=16 or higher for 512px models and
# BATCH_GPU=4 or higher for 1024px models.
#
# Arguments: $1 - model path or HuggingFace Hub id
#            $2 - RUN_DIR as given on the command line ("" when omitted)
# Globals:   RESOLUTION, BATCH_GPU (written); RUN_DIR (written for 1024px
#            Sana models, but only when the user did not supply one)
apply_model_profile() {
    local model="$1"
    local user_run_dir="$2"

    case "$model" in
        *FLUX.1-dev*)
            RESOLUTION=512
            BATCH_GPU=1
            ;;
        *stable-diffusion-3-medium*|*stable-diffusion-3.5-medium*)
            RESOLUTION=1024
            BATCH_GPU=2
            ;;
        *stable-diffusion-3.5-large*)
            RESOLUTION=1024
            BATCH_GPU=1
            ;;
        *512px*)
            RESOLUTION=512
            BATCH_GPU=16
            ;;
        *1024px*)
            RESOLUTION=1024
            case "$model" in
                *1?6[Bb]*|*1600[Mm]*) BATCH_GPU=4 ;;
                *) BATCH_GPU=8 ;;
            esac
            # 1024px runs use a dedicated default directory, but an output
            # directory passed explicitly by the user must not be discarded.
            if [ -z "$user_run_dir" ]; then
                RUN_DIR="ANON/data/image_experiment/sid_flow_1024"
            fi
            ;;
        *)
            # Default values if no match
            RESOLUTION=512
            BATCH_GPU=8
            ;;
    esac
}

apply_model_profile "$MODEL" "${5:-}"

echo "[INFO] Configured model: Resolution = ${RESOLUTION}px, Batch size per GPU = $BATCH_GPU"
            

# ANSI escape sequences used to colorize the log prefixes.
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color (reset)

# Logging helpers. Each takes the message as $1 and prints it with a colored
# level prefix. printf '%b' expands the backslash escapes stored in the color
# variables, so the output does not depend on the shell's `echo -e` support.

# Informational message, written to stdout.
print_info() {
    printf '%b\n' "${BLUE}[INFO]${NC} $1"
}

# Success message, written to stdout.
print_success() {
    printf '%b\n' "${GREEN}[SUCCESS]${NC} $1"
}

# Warning message, written to stdout.
print_warning() {
    printf '%b\n' "${YELLOW}[WARNING]${NC} $1"
}

# Error message, written to stderr so it stays visible even when the caller's
# stdout is being captured by a command substitution.
print_error() {
    printf '%b\n' "${RED}[ERROR]${NC} $1" >&2
}

# Report whether a command name resolves in the current shell.
# Arguments: $1 - command name to look up
# Returns:   the exit status of `command -v` (0 when the name resolves)
command_exists() {
    local tool="$1"
    { command -v "$tool"; } >/dev/null 2>&1
}

# Print the number of GPUs reported by `nvidia-smi --list-gpus`.
# Outputs: the GPU count on stdout; diagnostics go to stderr so that callers
#          capturing stdout with $(get_total_gpus) receive only the number.
# Exits:   1 when nvidia-smi is missing or fails.
get_total_gpus() {
    local gpu_list gpu_count

    if ! command_exists nvidia-smi; then
        print_error "nvidia-smi not found. Please ensure NVIDIA drivers are installed." >&2
        exit 1
    fi

    # Check the exit status explicitly: when nvidia-smi fails, whatever it
    # prints must not be counted as GPUs.
    if ! gpu_list=$(nvidia-smi --list-gpus); then
        print_error "nvidia-smi failed to list GPUs. Please check the NVIDIA driver." >&2
        exit 1
    fi

    # Count only lines that begin with "GPU " so any other line in the listing
    # is ignored. grep -c exits non-zero when the count is 0, hence `|| true`.
    gpu_count=$(printf '%s\n' "$gpu_list" | grep -c '^GPU ' || true)
    echo "$gpu_count"
}

# Select the NUM_GPUS GPUs with the lowest used memory and expose them through
# CUDA_VISIBLE_DEVICES.
# Globals: NUM_GPUS (read; lowered to the number of detected GPUs when the
#          request exceeds it), CUDA_VISIBLE_DEVICES (exported)
# Exits:   1 when nvidia-smi is missing, no GPU is detected, NUM_GPUS is
#          below 1, or the memory query returns nothing.
select_least_used_gpus() {
    local total_gpus usage_rows selected_gpus

    if ! command_exists nvidia-smi; then
        print_error "nvidia-smi not found. Please ensure NVIDIA drivers are installed."
        exit 1
    fi

    total_gpus=$(get_total_gpus)

    # Report a missing GPU directly instead of first clamping the request to
    # "0 GPUs" and then complaining about the clamped value.
    if [ "$total_gpus" -lt 1 ]; then
        print_error "No GPUs detected by nvidia-smi"
        exit 1
    fi

    # Validate requested number of GPUs
    if [ "$NUM_GPUS" -gt "$total_gpus" ]; then
        print_warning "Requested $NUM_GPUS GPUs but only $total_gpus available. Using $total_gpus GPUs."
        NUM_GPUS=$total_gpus
    fi

    if [ "$NUM_GPUS" -lt 1 ]; then
        print_error "Number of GPUs must be at least 1"
        exit 1
    fi

    print_info "Analyzing GPU usage to select $NUM_GPUS least used GPUs..."

    # One row per GPU, "GPU_ID MEMORY_USED_MB", least used first, trimmed to
    # the requested count.
    usage_rows=$(nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits |
        awk -F', ' '{print $1, $2}' |
        sort -k2,2n |
        head -n "$NUM_GPUS")

    if [ -z "$usage_rows" ]; then
        print_error "Failed to get GPU usage information"
        exit 1
    fi

    # Comma-separated list of the selected GPU ids
    selected_gpus=$(printf '%s\n' "$usage_rows" | awk '{print $1}' | paste -sd, -)

    print_info "Selected GPUs (least memory usage):"
    printf '%s\n' "$usage_rows" | while read -r gpu_id mem_used; do
        print_info "  GPU $gpu_id: ${mem_used}MB used"
    done

    export CUDA_VISIBLE_DEVICES="$selected_gpus"

    print_success "Using $NUM_GPUS GPUs: $selected_gpus"
}

# Verify that PyTorch can see CUDA, by asking the `python` on PATH to print
# torch.cuda.is_available().
# Exits: 1 when python is missing or the output does not contain "True".
check_cuda() {
    local torch_report

    command_exists python || {
        print_error "Python not found"
        exit 1
    }

    # A failing interpreter (for example torch not importable) simply leaves
    # the report without "True"; `|| true` keeps that from aborting here.
    torch_report=$(python -c "import torch; print(torch.cuda.is_available())" 2>/dev/null) || true

    case "$torch_report" in
        *True*)
            print_success "CUDA is available"
            ;;
        *)
            print_error "CUDA is not available. Please check PyTorch installation."
            exit 1
            ;;
    esac
}

# Ensure the output directory RUN_DIR exists and can be written to.
# Globals: RUN_DIR (read)
# Exits:   1 when RUN_DIR exists but is not writable, when the closest
#          existing ancestor is not a writable directory, or when mkdir fails.
create_output_dir() {
    local probe next

    if [ -d "$RUN_DIR" ]; then
        if [ ! -w "$RUN_DIR" ]; then
            print_error "Cannot write to output directory: $RUN_DIR"
            exit 1
        fi
        return 0
    fi

    # `mkdir -p` creates missing intermediate directories, so the directory
    # that has to be writable is the closest ancestor that already exists,
    # not necessarily the immediate parent.
    probe=$(dirname "$RUN_DIR")
    while [ ! -e "$probe" ]; do
        next=$(dirname "$probe")
        if [ "$next" = "$probe" ]; then
            break
        fi
        probe=$next
    done

    if [ ! -d "$probe" ] || [ ! -w "$probe" ]; then
        print_error "Cannot write to parent directory: $probe"
        exit 1
    fi

    if ! mkdir -p "$RUN_DIR"; then
        print_error "Failed to create output directory: $RUN_DIR"
        exit 1
    fi
    print_info "Created output directory: $RUN_DIR"
}

# Warn about dataset directories that are missing. This check is advisory:
# it only prints messages and never aborts the run.
check_datasets() {
    local coco_val="ANON/data/datasets/MS-COCO-256/val"
    local prompt_set="ANON/data/datasets/aesthetics_6_plus"
    local pair_set="ANON/data/datasets/midjourney-v6-llava/data"
    local absent=""

    print_info "Checking dataset paths..."

    # Validation images
    [ -d "$coco_val" ] || {
        print_warning "Dataset path not found: $coco_val"
        print_info "If you need to compute metrics during training, please provide MS-COCO-256/val."
        absent="yes"
    }

    # Prompt-only data
    [ -d "$prompt_set" ] || {
        print_warning "Dataset path not found: $prompt_set"
        print_info "If running data-free, you may provide either $prompt_set or $pair_set."
        absent="yes"
    }

    # Text/image pairs
    [ -d "$pair_set" ] || {
        print_warning "Dataset path not found: $pair_set"
        print_info "If using a Diffusion-GAN that needs real images, $pair_set is required."
        print_info "If running data-free, you may provide either $prompt_set or $pair_set."
        absent="yes"
    }

    if [ -z "$absent" ]; then
        print_success "All datasets found"
        return
    fi

    print_warning "Some datasets are missing. Training may fail if they are required."
    print_info "Continuing anyway..."
}

# Print the effective run configuration.
# The dataset paths shown here are the ones the training command is launched
# with (--text_image_pair_path, --data_prompt_text and --data respectively),
# so the summary matches what actually runs.
# Globals: MODEL, NUM_GPUS, CUDA_VISIBLE_DEVICES, RUN_DIR, NOISE_TYPE,
#          WEIGHTING_SCHEME (read)
show_config() {
    print_info "=== SiD-SANA [NAME] [NAME] ==="
    print_info "Model: $MODEL"
    print_info "GPUs: $NUM_GPUS (automatically selected)"
    print_info "Selected GPU IDs: $CUDA_VISIBLE_DEVICES"
    print_info "Output directory: $RUN_DIR"
    print_info "Training data: ANON/data/datasets/midjourney-v6-llava/data"
    print_info "Prompt data: ANON/data/datasets/aesthetics_6_plus"
    print_info "Validation data: ANON/data/datasets/MS-COCO-256/val"
    print_info "Noise type: $NOISE_TYPE"
    print_info "Weighting scheme: $WEIGHTING_SCHEME"
    echo
}

# Prepare the process environment before training starts.
# At present this only logs progress: the PyTorch-related exports below are
# commented out and therefore have no effect.
setup_environment() {
    print_info "Setting up environment..."

    # Optional PyTorch settings, currently disabled:
    #export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128
    #export CUDA_LAUNCH_BLOCKING=0

    print_success "Environment setup complete"
}


# Launch sid_dit_flux_train.py through torchrun with one process per GPU.
# Globals: NUM_GPUS, RUN_DIR, MODEL, RESOLUTION, BATCH_GPU, NOISE_TYPE,
#          WEIGHTING_SCHEME (read)
# NOTE(review): the --optimizer value is passed through literally as '[NAME]',
# which looks like a redacted placeholder -- confirm the intended optimizer.
start_training() {
    local global_batch cpu_offload

    print_info "Starting distributed training on selected GPUs..."
    print_info "Command parameters: RESOLUTION=$RESOLUTION, BATCH_GPU=$BATCH_GPU"

    # Per-model hook for the global batch size and CPU offload flag. Both
    # branches currently carry the same values.
    case "$MODEL" in
        *3?5-large*)
            global_batch=256
            cpu_offload=0
            ;;
        *)
            global_batch=256
            cpu_offload=0
            ;;
    esac

    # Build the torchrun argument list in the positional parameters. The
    # function is called without arguments, so nothing is lost by replacing
    # them.
    set -- \
        --standalone \
        "--nproc_per_node=$NUM_GPUS" \
        sid_dit_flux_train.py \
        --outdir "$RUN_DIR" \
        --resume "$RUN_DIR" \
        --data "ANON/data/datasets/MS-COCO-256/val" \
        --data_prompt_text "ANON/data/datasets/aesthetics_6_plus" \
        --text_image_pair_path "ANON/data/datasets/midjourney-v6-llava/data" \
        --dit_model "$MODEL" \
        --optimizer '[NAME]' \
        --resolution "$RESOLUTION" \
        --batch "$global_batch" \
        --cpu_offload "$cpu_offload" \
        --batch-gpu "$BATCH_GPU" \
        --lr 0.00001 \
        --glr 0.00001 \
        --cfg_train_fake 4.5 \
        --cfg_eval_fake 4.5 \
        --cfg_eval_real 4.5 \
        --alpha 1.0 \
        --init_timestep 999 \
        --num_steps 4 \
        --fp16 0 \
        --bf16 1 \
        --autocast_bf16 1 \
        --gradient_checkpointing 1 \
        --tick 2 \
        --snap 10 \
        --dump 10 \
        --duration 2 \
        --ls 1 \
        --lsg 100 \
        --metrics "fid_clip_10k_full" \
        --noise_type "$NOISE_TYPE" \
        --weighting_scheme "$WEIGHTING_SCHEME" \
        --train_diffusiongan 0 \
        --use_sd3_shift 0 \
        --nosubdir

    # Options kept for reference, currently disabled:
    #   --workers 0
    #   --text_encoders_dtype 'fp16'

    torchrun "$@"
}

# Exit hook. It currently only announces that cleanup is running; add real
# cleanup steps after the log line when they are needed.
cleanup() {
    print_info "Cleaning up..."
}

# Run the hook whenever the script exits.
trap 'cleanup' EXIT

# Main execution: validate the command-line settings, run the pre-flight
# checks, then launch training.
# Globals: MODEL, WEIGHTING_SCHEME, NUM_GPUS, NOISE_TYPE (read)
# Exits:   1 on any invalid setting (message via print_error)
main() {
    # Report the model that was actually selected, not a fixed name.
    print_info "Starting SiD-SANA training with $MODEL"

    # Validate weighting scheme (second positional parameter)
    case "$WEIGHTING_SCHEME" in
        "sid_legacy"|"snr_sqrt"|"snr"|"1_over_sigma2"|"1_over_sigma"|"1_minus_sigma_squared"|"1_minus_sigma")
            ;;
        *)
            print_error "Invalid weighting_scheme: $WEIGHTING_SCHEME. Must be one of: sid_legacy, snr_sqrt, snr, 1_over_sigma2, 1_over_sigma, 1_minus_sigma_squared, 1_minus_sigma"
            exit 1
            ;;
    esac

    # Validate number of GPUs: digits only, and greater than zero
    case "$NUM_GPUS" in
        ''|*[!0-9]*)
            print_error "Number of GPUs must be a positive integer"
            exit 1
            ;;
    esac
    if [ "$NUM_GPUS" -lt 1 ]; then
        print_error "Number of GPUs must be a positive integer"
        exit 1
    fi

    # Validate noise type
    case "$NOISE_TYPE" in
        "fresh"|"fixed"|"ddim")
            ;;
        *)
            print_error "Invalid noise_type: $NOISE_TYPE. Must be one of: fresh, fixed, ddim"
            exit 1
            ;;
    esac

    # Pre-flight checks
    select_least_used_gpus
    check_cuda
    check_datasets
    create_output_dir

    # Setup and start
    show_config
    setup_environment
    start_training

    print_success "Training completed successfully!"
}

# Print the command-line usage text to stdout.
show_usage() {
    cat <<EOF
Usage: $0 [MODEL_SELECTION] [WEIGHTING_SCHEME] [NUM_GPUS] [NOISE_TYPE] [RUN_DIR]
  MODEL_SELECTION: Model to use (default: auto)
    Options: auto, 600m, sprint_0.6b, sprint_1.6b, 1600m_1024, 1600m_512, sd3, sd3.5-medium, sd3.5-large, flux
    auto: Automatically select first available model
    600m: Sana_600M_512px_diffusers
    sprint_0.6b: Sana_Sprint_0.6B_1024px_teacher_diffusers
    sprint_1.6b: SANA_Sprint_1.6B_1024px_teacher_diffusers
    1600m_1024: Sana_1600M_1024px_BF16_diffusers
    1600m_512: Sana_1600M_512px_diffusers
    sd3: stabilityai/stable-diffusion-3-medium-diffusers
    sd3.5-medium: stabilityai/stable-diffusion-3.5-medium
    sd3.5-large: stabilityai/stable-diffusion-3.5-large
    flux: black-forest-labs/FLUX.1-dev
  WEIGHTING_SCHEME: Loss weighting scheme (default: 1_minus_sigma)
    Options: sid_legacy, snr_sqrt, snr, 1_over_sigma2, 1_over_sigma, 1_minus_sigma_squared, 1_minus_sigma
  NUM_GPUS: Number of GPUs to use (default: 4)
  NOISE_TYPE: Noise type for generation (default: fresh)
    Options: fresh, fixed, ddim
  RUN_DIR: Output directory (default: ANON/data/image_experiment/sid_flow) - can be omitted

This script automatically selects the N GPUs with the least memory usage.
Example: $0 auto
Example: $0 600m 1_minus_sigma
Example: $0 sprint_1.6b sid_legacy
Example: $0 1600m_1024 snr_sqrt 8 ddim
Example: $0 600m 1_minus_sigma 1 fresh /path/to/output
Example: $0 1600m_512 1_minus_sigma 1 fresh /path/to/output
EOF
}

# Show usage if help requested. ${1:-} keeps the test safe when the script is
# started without any argument.
if [ "${1:-}" = "-h" ] || [ "${1:-}" = "--help" ]; then
    show_usage
    exit 0
fi

# Run main function
main "$@"
# ---------------------------------------------------------------------------
# Usage examples
# (Kept as comments: bare Markdown fences and example command lines in this
# file would be interpreted by the shell after main returns.)
# ---------------------------------------------------------------------------
# Use all defaults (auto-select model)
#   bash run_sid_dit_flux_ANON.sh
#
# Use sd3.5-medium
#   bash run_sid_dit_flux_ANON.sh sd3.5-medium 1_minus_sigma 8
#
# Use sd3.5-large
#   bash run_sid_dit_flux_ANON.sh sd3.5-large 1_minus_sigma 8
#
# Specify model selection only
#   bash run_sid_dit_flux_ANON.sh flux
#
# Specify model and weighting scheme
#   bash run_sid_dit_flux_ANON.sh flux sid_legacy
#
# Specify model, weighting scheme, and GPUs
#   bash run_sid_dit_flux_ANON.sh flux 1_minus_sigma 2
#
# Specify model, weighting scheme, GPUs, and noise type
#   bash run_sid_dit_flux_ANON.sh flux snr_sqrt 8 ddim
#
# Specify all parameters
#   bash run_sid_dit_flux_ANON.sh flux sid_legacy 4 fixed /path/to/output
#
# Specify model and custom output directory (arguments are positional, so the
# ones before the output directory must be given as well)
#   bash run_sid_dit_flux_ANON.sh flux 1_minus_sigma 1 fresh /path/to/output