# Sweep configuration (requires bash: uses arrays).
# Every combination of the lists below is run by the loop that follows.

# Values forwarded to train.py as --afs (the loop accepts True / False).
AFS_list=(True)
# Dataset names forwarded as --dataset_name.
DATASET_list=("cifar10")
# Sampling step counts to sweep: 4, 5, 6, 7, 8.
NUM_STEPS_list=({4..8})
# Values forwarded as --freeze_net.
FREEZE_NET_list=(True)
# Values forwarded as --total_kimg.
TOTAL_KIMG_list=(5)

# Optimiser / plugin hyper-parameters (forwarded as --lr, --iter, --rank, --scale).
LR_LIST=(5e-5)
ITER_LIST=(160)
RANK_LIST=(16)
SCA_LIST=(0.5)

# Numeric ID of the first experiment; the loop zero-pads it and uses it as the
# prefix of each ./exps/<id>-... directory it reads back after training.
BEGIN_ID=0

# Settings that are identical for every run of the sweep.
SAMPLE_TEA="dpmpp"
MAX_ORDER=3
PREDICT_X0=True
LOW_ORDER_FINAL=True
USE_STEP_CONDITION=False
IS_SECOND_STAGE=False
SOLVER_LIST=('euler')

#######################################
# Print a learning rate in plain decimal notation, e.g. 5e-5 -> 0.00005.
# The string is passed to train.py as --lr and is also part of the
# experiment directory name that is read back after training.
# Arguments: $1 - learning rate as written in LR_LIST
# Outputs:   the decimal string on stdout
# Returns:   non-zero if $1 is not a valid number
#######################################
format_lr() {
    # Decimal keeps the literal's exponent, so 1e-4 -> 0.0001 and
    # 5e-5 -> 0.00005 (same as '%.4f' / '%.5f' for those values), and any
    # other rate gets a sensible value too. The rate is passed via argv,
    # never interpolated into the Python source.
    python -c 'import sys; from decimal import Decimal; print(format(Decimal(sys.argv[1]), "f"))' "$1"
}

#######################################
# Set the per-dataset globals used by training, sampling and evaluation.
# Globals written: M, SEEDS, BATCH_SIZE, SCHEDULE_TYPE, SCHEDULE_RHO,
#   SCHEDULE_TAG (schedule part of the experiment dir name),
#   GUIDANCE_ARGS (extra train.py flags), ref_path (FID stats; "" if unknown)
# Arguments: $1 - dataset name
#######################################
set_dataset_config() {
    # Defaults; ms_coco overrides M and SEEDS, so they are reset per dataset.
    M=3
    SEEDS='50000-99999'
    GUIDANCE_ARGS=()
    case "$1" in
        lsun_bedroom_ldm)
            BATCH_SIZE=8
            SCHEDULE_TYPE="discrete"
            SCHEDULE_RHO=1
            SCHEDULE_TAG="discrete1.0"
            GUIDANCE_ARGS=(--guidance_type=uncond)
            ;;
        ms_coco)
            BATCH_SIZE=4
            SCHEDULE_TYPE="discrete"
            SCHEDULE_RHO=1
            SCHEDULE_TAG="discrete1.0"
            SEEDS='0-29999'
            M=2
            GUIDANCE_ARGS=(--guidance_type=cfg --guidance_rate=1.0)
            ;;
        imagenet64)
            BATCH_SIZE=32
            SCHEDULE_TYPE="polynomial"
            SCHEDULE_RHO=7
            SCHEDULE_TAG="poly7.0"
            ;;
        *)
            BATCH_SIZE=128
            SCHEDULE_TYPE="polynomial"
            SCHEDULE_RHO=7
            SCHEDULE_TAG="poly7.0"
            ;;
    esac

    case "$1" in
        cifar10)          ref_path="./stats/cifar10-32x32.npz" ;;
        ffhq)             ref_path="./stats/ffhq-64x64.npz" ;;
        imagenet64)       ref_path="./stats/imagenet-64x64.npz" ;;
        lsun_bedroom_ldm) ref_path="./stats/lsun_bedroom-256x256.npz" ;;
        ms_coco)          ref_path="./stats/ms_coco-512x512.npz" ;;
        *)                ref_path="" ;;
    esac
}

#######################################
# Train, sample and evaluate one hyper-parameter combination.
# Globals read: AFS, DATASET, FREEZE_NET, LR, LR_DECIMAL, ITER, RANK, SCALE,
#   TOTAL_KIMG, NUM_STEPS, the constants above, set_dataset_config's outputs
# Globals written: BEGIN_ID (incremented), DIR_ID, NFE, MODEL_PATH, PLUGIN_PATH
# Returns: 1 if training fails (sampling and evaluation are then skipped)
#######################################
run_experiment() {
    local afs_suffix exp_dir solver
    local -a train_args

    echo "AFS: $AFS, DATASET: $DATASET, FREEZE_NET: $FREEZE_NET, LR: $LR_DECIMAL, ITER: $ITER, RANK: $RANK, SCALE: $SCALE, TOTAL_KIMG: $TOTAL_KIMG, NUM_STEPS: $NUM_STEPS"

    # 5-digit zero-padded run ID, advanced once per combination. Presumably
    # this mirrors how train.py numbers ./exps/ directories -- confirm.
    printf -v DIR_ID '%05d' "$BEGIN_ID"
    BEGIN_ID=$((BEGIN_ID + 1))

    # AFS is validated before the sweep starts, so only these two cases occur.
    if [[ "$AFS" == "True" ]]; then
        NFE=$((NUM_STEPS - 2))
        afs_suffix="-afs"
    else
        NFE=$((NUM_STEPS - 1))
        afs_suffix=""
    fi

    train_args=(
        --seed=1234 --dataset_name="$DATASET" --total_kimg="$TOTAL_KIMG"
        --batch="$BATCH_SIZE" --lr="$LR_DECIMAL" --num_steps="$NUM_STEPS" --M="$M"
        --afs="$AFS" --sampler_tea="$SAMPLE_TEA" --max_order="$MAX_ORDER"
        --predict_x0="$PREDICT_X0" --lower_order_final="$LOW_ORDER_FINAL"
        --schedule_type="$SCHEDULE_TYPE" --schedule_rho="$SCHEDULE_RHO"
        --use_step_condition="$USE_STEP_CONDITION"
        --is_second_stage="$IS_SECOND_STAGE" --freeze_net="$FREEZE_NET"
        "${GUIDANCE_ARGS[@]}"
        --rank="$RANK" --iter="$ITER" --scale="$SCALE"
    )
    if ! torchrun --standalone --nproc_per_node=1 --master_port=12345 train.py "${train_args[@]}"; then
        echo "train.py failed for run ${DIR_ID} (${DATASET}, NUM_STEPS=${NUM_STEPS}); skipping sampling and evaluation" >&2
        return 1
    fi

    exp_dir="./exps/${DIR_ID}-${DATASET}-${NUM_STEPS}-${NFE}-${SAMPLE_TEA}-${MAX_ORDER}-${SCHEDULE_TAG}-${ITER}-${LR_DECIMAL}-${RANK}-${SCALE}${afs_suffix}"
    MODEL_PATH="${exp_dir}/network-${FREEZE_NET}-${TOTAL_KIMG}-snapshot-best.pkl"
    PLUGIN_PATH="${exp_dir}/plugin-${FREEZE_NET}-${TOTAL_KIMG}-snapshot-best.pkl"

    for solver in "${SOLVER_LIST[@]}"; do
        if ! torchrun --standalone --nproc_per_node=1 --master_port=12345 sample.py \
            --dataset_name="$DATASET" --model_path="$MODEL_PATH" --plugin_path="$PLUGIN_PATH" \
            --seeds="$SEEDS" --batch="$BATCH_SIZE" --solver="$solver"; then
            echo "sample.py failed for solver ${solver} (run ${DIR_ID}); skipping FID" >&2
            continue
        fi

        # ################# Evaluation #################
        if [[ -z "$ref_path" ]]; then
            echo "No FID reference statistics known for dataset '${DATASET}'; skipping FID" >&2
            continue
        fi
        python fid.py calc --model_path="$MODEL_PATH" \
            --images="./samples/${DATASET}/${solver}_step${NUM_STEPS}_nfe${NFE}_best_${FREEZE_NET}" \
            --ref="$ref_path"
    done
}

# Fail fast on an AFS value that would leave NFE and the experiment path undefined.
for AFS in "${AFS_list[@]}"; do
    if [[ "$AFS" != "True" && "$AFS" != "False" ]]; then
        echo "Unsupported AFS value '${AFS}' (expected True or False)" >&2
        exit 1
    fi
done

for AFS in "${AFS_list[@]}"; do
    for DATASET in "${DATASET_list[@]}"; do
        set_dataset_config "$DATASET"
        for FREEZE_NET in "${FREEZE_NET_list[@]}"; do
            for LR in "${LR_LIST[@]}"; do
                LR_DECIMAL=$(format_lr "$LR") || {
                    echo "Invalid learning rate '${LR}' in LR_LIST" >&2
                    exit 1
                }
                for ITER in "${ITER_LIST[@]}"; do
                    for RANK in "${RANK_LIST[@]}"; do
                        for SCALE in "${SCA_LIST[@]}"; do
                            for TOTAL_KIMG in "${TOTAL_KIMG_list[@]}"; do
                                for NUM_STEPS in "${NUM_STEPS_list[@]}"; do
                                    run_experiment
                                done
                            done
                        done
                    done
                done
            done
        done
    done
done