#!/bin/bash
# Make the repository's src/ directory importable. The ${VAR:+...} form adds
# the ':' separator only when PYTHONPATH is already non-empty, so an unset
# PYTHONPATH does not leave an empty leading path entry.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}$(pwd)/src"

# GPUs used by the full-sweep mode (value for CUDA_VISIBLE_DEVICES).
readonly FIXED_GPUS="0,1,2,3"

# Seeds to train with; every selected dataset is trained once per seed.
readonly SEEDS=(42 43 44)

# Train either one dataset (exactly one argument) or every dataset directory
# under regression_data/ (no argument), once per seed in SEEDS.
# Relies on SEEDS and FIXED_GPUS being defined earlier in this script.
#
# Usage: <script> [DATASET_NAME]
# Exit status: 0 if every run succeeded; 1 if the dataset was missing, no
# datasets were found, or any training run failed; 2 on a usage error.

#######################################
# Launch one training run with `accelerate launch` and log its outcome.
# Arguments:
#   $1   - comma-separated GPU ids for CUDA_VISIBLE_DEVICES; one process is
#          started per id
#   $2   - dataset name (directory name under regression_data/)
#   $3   - random seed
#   $4.. - extra Hydra overrides, inserted right after skip_mode=false
# Outputs: a success/failure line and a timestamp on stdout.
# Returns: the exit status of `accelerate launch`.
#######################################
launch_training() {
    local gpus=$1 dataset=$2 seed=$3
    shift 3
    local -a gpu_ids
    IFS=',' read -r -a gpu_ids <<< "$gpus"
    local status=0
    # The inner single quotes keep the name a quoted string for the Hydra override.
    CUDA_VISIBLE_DEVICES="$gpus" accelerate launch \
        --num_processes="${#gpu_ids[@]}" --main_process_port=12359 \
        -m src.train.train_search_rl_expert \
        +experiment=RL \
        dataset.name="'$dataset'" \
        dataset.params.data_dir=regression_data \
        eval_mode=false \
        skip_mode=false \
        "$@" \
        batch_size=128 \
        save_every_n_epochs=114 \
        seed="$seed" || status=$?
    if (( status == 0 )); then
        echo "✓ Dataset $dataset (seed=$seed) training completed"
    else
        echo "✗ Dataset $dataset (seed=$seed) training failed"
    fi
    echo "Time: $(date)"
    return "$status"
}

# Reject extra arguments instead of silently starting the full sweep.
if (( $# > 1 )); then
    echo "Usage: $0 [DATASET_NAME]" >&2
    exit 2
fi

# Number of failed training runs (plus missing-dataset errors).
failures=0

if (( $# == 1 )); then
    # Train the specified dataset
    DATASET_NAME=$1
    # NOTE(review): this mode uses GPUs 4-7, 200 epochs and no
    # reinforce.expert_ce_weight override, unlike the sweep below (FIXED_GPUS,
    # 100 epochs, expert_ce_weight=0.05). Kept as-is; confirm it is intended.
    single_gpus="4,5,6,7"
    echo "Starting training for specified dataset: $DATASET_NAME"
    echo "Using GPU: $single_gpus"
    echo "Time: $(date)"

    # An empty name would resolve to regression_data/ itself, so reject it too.
    if [[ -n "$DATASET_NAME" && -d "regression_data/$DATASET_NAME" ]]; then
        for SEED in "${SEEDS[@]}"; do
            echo ""
            echo "Using seed: $SEED"
            echo "Time: $(date)"
            launch_training "$single_gpus" "$DATASET_NAME" "$SEED" \
                num_epochs=200 || failures=$((failures + 1))
        done
    else
        echo "× Dataset $DATASET_NAME does not exist"
        failures=$((failures + 1))
    fi

    echo "Time: $(date)"
else
    # Iterate over all tasks in regression_data
    echo "Starting to read all tasks in regression_data"
    echo "Time: $(date)"

    # Collect one dataset name per sub-directory. The -d test also skips the
    # literal pattern that remains when the glob matches nothing.
    datasets=()
    for dir in regression_data/*/; do
        if [[ -d "$dir" ]]; then
            dir=${dir%/}
            dataset_name=${dir##*/}
            datasets+=("$dataset_name")
            echo "√ Dataset $dataset_name"
        fi
    done

    # Sort names line by line with mapfile; unlike word-splitting an unquoted
    # $(...), this cannot glob-expand names containing *, ? or [.
    if (( ${#datasets[@]} > 0 )); then
        mapfile -t datasets < <(printf '%s\n' "${datasets[@]}" | sort)
    fi
    echo "Found ${#datasets[@]} datasets:"
    for i in "${!datasets[@]}"; do
        echo "  $((i+1)). ${datasets[i]}"
    done

    if (( ${#datasets[@]} == 0 )); then
        echo "× No datasets found under regression_data (check the working directory)" >&2
        failures=$((failures + 1))
    fi

    # Every dataset x seed combination, run sequentially in this order.
    # Looping over both arrays directly (instead of packing "name:seed" strings
    # and splitting them on ':') keeps dataset names containing ':' intact.
    echo "Creating task queue..."
    total_tasks=$(( ${#datasets[@]} * ${#SEEDS[@]} ))
    echo "Total $total_tasks tasks to execute"
    echo "Using GPU: $FIXED_GPUS"
    echo "Time: $(date)"

    task_number=0
    for dataset_name in "${datasets[@]}"; do
        for seed in "${SEEDS[@]}"; do
            task_number=$((task_number + 1))
            echo ""
            echo "============================================================"
            echo "Task $task_number/$total_tasks: $dataset_name (seed=$seed)"
            echo "Using GPU: $FIXED_GPUS"
            echo "Time: $(date)"
            echo "============================================================"
            launch_training "$FIXED_GPUS" "$dataset_name" "$seed" \
                reinforce.expert_ce_weight=0.05 num_epochs=100 \
                || failures=$((failures + 1))
        done
    done

    echo ""
    echo "============================================================"
    echo "All seeds training for all datasets completed!"
    echo "Failures: $failures"
    echo "Time: $(date)"
    echo "============================================================"
fi

# Propagate failures so schedulers/CI can detect them.
if (( failures > 0 )); then
    exit 1
fi