#!/usr/bin/env bash
# Launches a single run of main1.py with a pinned set of command-line values.
# Only DATA_ROOT may be overridden from the environment; every other value is
# fixed below. main1.py is looked up relative to the current working directory.
#
# Strict mode: abort on errors, on unset variables and on failures inside
# pipelines; xtrace echoes each command (with expanded values) to stderr.
set -o errexit -o nounset -o xtrace -o pipefail

# --- Dataset / model ---------------------------------------------------------
DATA_ROOT="${DATA_ROOT:-./data}"   # falls back to ./data when unset or empty
DATASET="cifar100st"
MODEL_NAME="resnet110"

# --- Run length and batching (--epochs, --workers, --batch_size) -------------
EPOCHS=120
WORKERS=2
BATCH_SIZE=256

# --- Per-user settings (--num_users, --local_bs, --local_ep) -----------------
NUM_USERS=32
LOCAL_BS=128
LOCAL_EP=4

# --- Passed through as --K and --I; meaning is defined by main1.py -----------
K=1
I=1

# --- Optimizer / learning-rate settings --------------------------------------
BASE_LR=0.065
MIN_LR=0.002
WARMUP_ROUNDS=20
MOMENTUM=0.9
WEIGHT_DECAY=5e-4
GRAD_CLIP=5.0

# --- Method-specific coefficients (semantics live in main1.py) ---------------
LAMDA=1.0        # spelling matches the --lamda flag; do not "fix" it here
BETA_Y=0.035
GAMMA_X=0.85
GAMMA_Y=0.12
Y_CLIP=8.0

# --- Label smoothing and mixup -----------------------------------------------
LABEL_SMOOTHING=0.0
MIXUP_ALPHA=0.04
MIXUP_START_ROUND=85
MIXUP_FULL_ROUND=110

# --- Server rehearsal --------------------------------------------------------
SERVER_REHEARSAL_STEPS=6
SERVER_REHEARSAL_LR=0.004
SERVER_REHEARSAL_START_ROUND=20
SERVER_LABEL_SMOOTHING=0.0

# --- Evaluation --------------------------------------------------------------
EVAL_NUM_CLIENTS=32

# --- Pretraining -------------------------------------------------------------
# Original note: "lower start" -- presumably refers to the small pretrain
# budget / learning rate below; confirm against main1.py.
PRETRAIN_EPOCHS=1
PRETRAIN_LR=0.0011

# --- Seed --------------------------------------------------------------------
SEED=40

# --- Weights & Biases logging ------------------------------------------------
WANDB_PROJECT="avg"
WANDB_ENTITY="hq1351-wayne-state-university"

# Assemble the argument vector as an array so each flag/value pair sits on its
# own line without backslash continuations. The order matches the flag order
# main1.py has always been invoked with.
train_args=(
  --dataset "${DATASET}"
  --model_name "${MODEL_NAME}"
  --data_root "${DATA_ROOT}"

  --epochs "${EPOCHS}"
  --workers "${WORKERS}"
  --batch_size "${BATCH_SIZE}"

  --num_users "${NUM_USERS}"
  --local_bs "${LOCAL_BS}"
  --local_ep "${LOCAL_EP}"

  --K "${K}"
  --I "${I}"

  --base_lr "${BASE_LR}"
  --min_lr "${MIN_LR}"
  --warmup_rounds "${WARMUP_ROUNDS}"
  --momentum "${MOMENTUM}"
  --weight_decay "${WEIGHT_DECAY}"
  --grad_clip "${GRAD_CLIP}"

  --lamda "${LAMDA}"
  --beta_y "${BETA_Y}"
  --gamma_x "${GAMMA_X}"
  --gamma_y "${GAMMA_Y}"
  --y_clip "${Y_CLIP}"

  --label_smoothing "${LABEL_SMOOTHING}"
  --mixup_alpha "${MIXUP_ALPHA}"
  --mixup_start_round "${MIXUP_START_ROUND}"
  --mixup_full_round "${MIXUP_FULL_ROUND}"

  --server_rehearsal_steps "${SERVER_REHEARSAL_STEPS}"
  --server_rehearsal_lr "${SERVER_REHEARSAL_LR}"
  --server_rehearsal_start_round "${SERVER_REHEARSAL_START_ROUND}"
  --server_label_smoothing "${SERVER_LABEL_SMOOTHING}"

  --eval_num_clients "${EVAL_NUM_CLIENTS}"

  --pretrain_epochs "${PRETRAIN_EPOCHS}"
  --pretrain_lr "${PRETRAIN_LR}"

  --random_seed "${SEED}"

  --wandb
  --wandb_project "${WANDB_PROJECT}"
  --wandb_entity "${WANDB_ENTITY}"
)

# -u keeps Python's stdout/stderr unbuffered. This is the script's last
# command, so its exit status becomes the script's exit status.
python3 -u main1.py "${train_args[@]}"