#!/bin/bash

# =================================================================
# 1. Experiment configuration
# =================================================================

# --- Hardware ---
# GPU IDs used as concurrent job slots; an ID may be repeated to run several
# jobs on the same card at once.
# WARNING: (0 0 0) together with batch size 64 very easily runs out of GPU
# memory -- adjust this to what your GPUs can actually handle.
GPUS=(
    0
    1
    2
)

# --- Federated-learning basics ---
G_NUM_CLIENTS=10
G_FRACTION=0.5
G_ROUNDS=50
G_LOCAL_EPOCHS=1
G_BATCH_SIZE=64

# --- Model and task ---
TRAINING_MODE="head"
MODEL_NAME="resnet18"
OPTIMIZER="sgd"
# Data-heterogeneity parameter (values used: 0.1, 0.5, 10).
# NOTE(review): presumably a Dirichlet concentration where smaller means more
# non-IID -- confirm in main.py.
ALPHA=10

# --- Job matrix: every dataset x method x seed combination is launched ---
DATASETS=(
    "mnist"
    "fashionmnist"
    "cifar10"
)
METHODS_LIST=(
    "FedAvg"
    "FedProx"
    "FedSophia"
    "FedNew"
    "FedDANE"
    "FedNewton"
)
SEEDS=(
    0
    1
    2
)

# --- Fixed per-method hyperparameters ---
# Sophia-style methods
SOPHIA_LR=0.05
SOPHIA_RHO=0.04
BETAS="0.9,0.99"
# FedDANE
DANE_MU=0.01
DANE_LR=0.1
# Shared client learning rate / proximal coefficient
FIXED_LR=0.01
FIXED_MU=0.01
# FedNewton
NEWTON_HESSIAN_BATCHES=64
NEWTON_DAMPING=0.0001
NEWTON_LR=0.1

# =================================================================
# 2. Output directory setup and concurrency control
# =================================================================

# Each run gets its own timestamped directory (relative to the current working
# directory) named after the model and training mode.
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
RUN_ROOT="runs/run_${TIMESTAMP}_${MODEL_NAME}_${TRAINING_MODE}"
# Absolute path of the directory containing this script.
BASE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

if ! mkdir -p "${RUN_ROOT}"; then
    echo "ERROR: cannot create run directory ${RUN_ROOT}" >&2
    exit 1
fi

# Snapshot the code used for this run. Best effort: a missing file is not fatal.
cp "${BASE_DIR}/main.py"     "${RUN_ROOT}/main.py"     2>/dev/null || true
cp "${BASE_DIR}/analysis.py" "${RUN_ROOT}/analysis.py" 2>/dev/null || true
# Copy this very script (whatever its actual filename) as run.sh.
cp "${BASH_SOURCE[0]}"       "${RUN_ROOT}/run.sh"      2>/dev/null || true

# Without at least one token, the first `read` from fd 6 would block forever.
if (( ${#GPUS[@]} == 0 )); then
    echo "ERROR: GPUS is empty; at least one GPU token is required" >&2
    exit 1
fi

# --- Token pool: a FIFO held open read-write on fd 6 ---
# The FIFO lives in a private mktemp directory (no predictable /tmp name).
# Opening it with <> does not block waiting for a peer (Linux semantics). The
# path is unlinked immediately, so the pipe exists only while fd 6 is open.
fifo_dir=$(mktemp -d) || { echo "ERROR: mktemp -d failed" >&2; exit 1; }
tmp_fifo="${fifo_dir}/gpu_tokens.fifo"
if ! mkfifo "$tmp_fifo"; then
    echo "ERROR: mkfifo ${tmp_fifo} failed" >&2
    rm -rf -- "$fifo_dir"
    exit 1
fi
exec 6<>"$tmp_fifo"
rm -rf -- "$fifo_dir"

# Put one token per GPUS entry into the pool; a job starts only after reading one.
for gpu in "${GPUS[@]}"; do
    echo "$gpu" >&6
done


# On Ctrl-C / SIGTERM: reset the traps first, so the SIGTERM that `kill 0` also
# delivers to this shell cannot re-enter the handler. Then signal the whole
# process group (this script and all background jobs). Exit non-zero as a
# fallback.
trap 'echo ">>> Terminating..."; trap - SIGINT SIGTERM; kill 0; exit 130' SIGINT SIGTERM

echo "=================================================="
echo "       Start Training (Corrected Version)"
echo "=================================================="

# =================================================================
# 3. Main loop (training phase)
# =================================================================

# For every dataset x method x seed: take a GPU token from fd 6, run main.py in
# the background pinned to that GPU, then return the token when the job exits.
for DATASET in "${DATASETS[@]}"; do
    echo "##################################################"
    echo "Queuing Dataset: ${DATASET}"
    echo "##################################################"


    # Pre-download check (disabled):
    # python main.py --method FedAvg --dataset "${DATASET}" --model_name "${MODEL_NAME}" --rounds 0 --num_clients 2 > /dev/null 2>&1

    LOG_DIR="${RUN_ROOT}/${DATASET}/logs"
    FIG_DIR="${RUN_ROOT}/${DATASET}/figs"
    mkdir -p "${LOG_DIR}" "${FIG_DIR}"

    for METHOD in "${METHODS_LIST[@]}"; do
        for seed in "${SEEDS[@]}"; do

            # Method-specific flags. Use an array so every value stays one argv word.
            extra_args=()
            case "$METHOD" in
                "FedAvg")   extra_args=(--lr "$FIXED_LR") ;;
                "FedProx")  extra_args=(--lr "$FIXED_LR" --mu "$FIXED_MU") ;;
                "FedSophia"|"FedNew") extra_args=(--lr "$FIXED_LR" --sophia_lr "$SOPHIA_LR" --rho "$SOPHIA_RHO" --betas "$BETAS") ;;
                "FedDANE")  extra_args=(--lr "$FIXED_LR" --newton_lr "$DANE_LR" --mu "$DANE_MU") ;;
                "FedNewton") extra_args=(--newton_lr "$NEWTON_LR" --damping "$NEWTON_DAMPING" --hessian_batches "$NEWTON_HESSIAN_BATCHES") ;;
                *) echo "Skipping: $METHOD"; continue ;;
            esac

            # 1. Acquire a token. This blocks until a running job returns one.
            #    A failed read would otherwise start a job with an empty GPU ID.
            if ! read -r -u 6 gpu_id; then
                echo "ERROR: could not read a GPU token from fd 6" >&2
                exit 1
            fi

            echo "    -> [GPU $gpu_id] Submitted: $METHOD (Seed $seed, Data $DATASET)"

            # Build the command as an argv array (no word-splitting of paths/values).
            # main.py is resolved next to this script, not relative to the CWD.
            cmd=(
                python -u "${BASE_DIR}/main.py"
                --method "$METHOD"
                --training_mode "$TRAINING_MODE"
                --dataset "$DATASET"
                --model_name "$MODEL_NAME"
                --seed "$seed"
                --num_clients "$G_NUM_CLIENTS"
                --fraction "$G_FRACTION"
                --rounds "$G_ROUNDS"
                --local_epochs "$G_LOCAL_EPOCHS"
                --batch_size "$G_BATCH_SIZE"
                --optimizer "$OPTIMIZER"
                --log_dir "$LOG_DIR"
                --alpha "$ALPHA"
                "${extra_args[@]}"
            )
            log_file="${LOG_DIR}/log_${METHOD}_s${seed}.txt"

            # 2. Run the job in the background. The subshell gets a snapshot of
            #    cmd/log_file/gpu_id as they are at fork time.
            {
                rc=0
                CUDA_VISIBLE_DEVICES="$gpu_id" "${cmd[@]}" > "$log_file" 2>&1 || rc=$?

                if (( rc == 0 )); then
                    echo "    <- [GPU $gpu_id] Finished:  $METHOD (Seed $seed, Data $DATASET)"
                else
                    echo "    <- [GPU $gpu_id] FAILED (exit $rc): $METHOD (Seed $seed, Data $DATASET), see $log_file" >&2
                fi

                # 3. Return the token whether the job succeeded or failed.
                echo "$gpu_id" >&6
            } &

        done
    done
done

echo ">>> All tasks submitted. Waiting for completion..."
# Barrier: block until every background training job (all datasets) has exited.
wait

# Release the token-pool file descriptor.
exec 6>&-

# =================================================================
# 4. Analysis phase
# =================================================================

echo "=================================================="
echo "       Start Analysis"
echo "=================================================="

for DATASET in "${DATASETS[@]}"; do
    echo "Generating Analysis for ${DATASET}..."

    # Derive this dataset's paths fresh here. LOG_DIR/FIG_DIR from the training
    # loop still hold the values for the last dataset only.
    dataset_root="${RUN_ROOT}/${DATASET}"
    CURRENT_LOG_DIR="${dataset_root}/logs"
    CURRENT_FIG_DIR="${dataset_root}/figs"

    # Prefer the snapshot copy of analysis.py saved in the run directory;
    # otherwise fall back to the one next to this script.
    if [[ -f "${RUN_ROOT}/analysis.py" ]]; then
        analysis_script="${RUN_ROOT}/analysis.py"
    else
        analysis_script="${BASE_DIR}/analysis.py"
    fi

    python "$analysis_script" --log_dir "${CURRENT_LOG_DIR}" --fig_dir "${CURRENT_FIG_DIR}"
done

echo "All experiments completed."
