usage: train.py [-h] [--data-dir DIR] [--dataset NAME] [--train-split NAME]
                [--val-split NAME] [--train-num-samples N]
                [--val-num-samples N] [--dataset-download]
                [--class-map FILENAME] [--input-img-mode INPUT_IMG_MODE]
                [--input-key INPUT_KEY] [--target-key TARGET_KEY]
                [--dataset-trust-remote-code] [--model MODEL] [--pretrained]
                [--pretrained-path PRETRAINED_PATH]
                [--initial-checkpoint PATH] [--resume PATH] [--no-resume-opt]
                [--num-classes N] [--gp POOL] [--img-size N] [--in-chans N]
                [--input-size N N N] [--crop-pct N] [--mean MEAN [MEAN ...]]
                [--std STD [STD ...]] [--interpolation NAME] [-b N] [-vb N]
                [--channels-last] [--fuser FUSER] [--grad-accum-steps N]
                [--grad-checkpointing] [--fast-norm]
                [--model-kwargs [MODEL_KWARGS ...]]
                [--head-init-scale HEAD_INIT_SCALE]
                [--head-init-bias HEAD_INIT_BIAS]
                [--torchcompile-mode TORCHCOMPILE_MODE]
                [--torchscript | --torchcompile [TORCHCOMPILE]]
                [--device DEVICE] [--amp] [--amp-dtype AMP_DTYPE]
                [--amp-impl AMP_IMPL] [--model-dtype MODEL_DTYPE]
                [--no-ddp-bb] [--synchronize-step] [--local_rank LOCAL_RANK]
                [--device-modules DEVICE_MODULES [DEVICE_MODULES ...]]
                [--opt OPTIMIZER] [--opt-eps EPSILON]
                [--opt-betas BETA [BETA ...]] [--momentum M]
                [--weight-decay WEIGHT_DECAY] [--clip-grad NORM]
                [--clip-mode CLIP_MODE] [--layer-decay LAYER_DECAY]
                [--opt-kwargs [OPT_KWARGS ...]] [--sched SCHEDULER]
                [--sched-on-updates] [--lr LR] [--lr-base LR]
                [--lr-base-size DIV] [--lr-base-scale SCALE]
                [--lr-noise pct, pct [pct, pct ...]] [--lr-noise-pct PERCENT]
                [--lr-noise-std STDDEV] [--lr-cycle-mul MULT]
                [--lr-cycle-decay MULT] [--lr-cycle-limit N]
                [--lr-k-decay LR_K_DECAY] [--warmup-lr LR] [--min-lr LR]
                [--epochs N] [--epoch-repeats N] [--start-epoch N]
                [--decay-milestones MILESTONES [MILESTONES ...]]
                [--decay-epochs N] [--warmup-epochs N] [--warmup-prefix]
                [--cooldown-epochs N] [--patience-epochs N]
                [--decay-rate RATE] [--no-aug]
                [--train-crop-mode TRAIN_CROP_MODE] [--scale PCT [PCT ...]]
                [--ratio RATIO [RATIO ...]] [--hflip HFLIP] [--vflip VFLIP]
                [--color-jitter PCT] [--color-jitter-prob PCT]
                [--grayscale-prob PCT] [--gaussian-blur-prob PCT] [--aa NAME]
                [--aug-repeats AUG_REPEATS] [--aug-splits AUG_SPLITS]
                [--jsd-loss] [--bce-loss] [--bce-sum]
                [--bce-target-thresh BCE_TARGET_THRESH]
                [--bce-pos-weight BCE_POS_WEIGHT] [--reprob PCT]
                [--remode REMODE] [--recount RECOUNT] [--resplit]
                [--mixup MIXUP] [--cutmix CUTMIX]
                [--cutmix-minmax CUTMIX_MINMAX [CUTMIX_MINMAX ...]]
                [--mixup-prob MIXUP_PROB]
                [--mixup-switch-prob MIXUP_SWITCH_PROB]
                [--mixup-mode MIXUP_MODE] [--mixup-off-epoch N]
                [--smoothing SMOOTHING]
                [--train-interpolation TRAIN_INTERPOLATION] [--drop PCT]
                [--drop-connect PCT] [--drop-path PCT] [--drop-block PCT]
                [--bn-momentum BN_MOMENTUM] [--bn-eps BN_EPS] [--sync-bn]
                [--dist-bn DIST_BN] [--split-bn] [--model-ema]
                [--model-ema-force-cpu] [--model-ema-decay MODEL_EMA_DECAY]
                [--model-ema-warmup] [--seed S]
                [--worker-seeding WORKER_SEEDING] [--log-interval N]
                [--recovery-interval N] [--checkpoint-hist N] [-j N]
                [--save-images] [--pin-mem] [--no-prefetcher] [--output PATH]
                [--experiment NAME] [--eval-metric EVAL_METRIC] [--tta N]
                [--use-multi-epochs-loader] [--log-wandb]
                [--wandb-project WANDB_PROJECT]
                [--wandb-tags WANDB_TAGS [WANDB_TAGS ...]]
                [--wandb-resume-id ID] [--naflex-loader]
                [--naflex-train-seq-lens NAFLEX_TRAIN_SEQ_LENS [NAFLEX_TRAIN_SEQ_LENS ...]]
                [--naflex-max-seq-len NAFLEX_MAX_SEQ_LEN]
                [--naflex-patch-sizes NAFLEX_PATCH_SIZES [NAFLEX_PATCH_SIZES ...]]
                [--naflex-patch-size-probs NAFLEX_PATCH_SIZE_PROBS [NAFLEX_PATCH_SIZE_PROBS ...]]
                [--naflex-loss-scale NAFLEX_LOSS_SCALE]
                [DIR]

PyTorch ImageNet Training

positional arguments:
  DIR                   path to dataset (positional is *deprecated*, use
                        --data-dir)

options:
  -h, --help            show this help message and exit

Dataset parameters:
  --data-dir DIR        path to dataset (root dir)
  --dataset NAME        dataset type + name ("<type>/<name>") (default:
                        ImageFolder or ImageTar if empty)
  --train-split NAME    dataset train split (default: train)
  --val-split NAME      dataset validation split (default: validation)
  --train-num-samples N
                        Manually specify num samples in train split, for
                        IterableDatasets.
  --val-num-samples N   Manually specify num samples in validation split, for
                        IterableDatasets.
  --dataset-download    Allow download of dataset for torch/ and tfds/
                        datasets that support it.
  --class-map FILENAME  path to class to idx mapping file (default: "")
  --input-img-mode INPUT_IMG_MODE
                        Dataset image conversion mode for input images.
  --input-key INPUT_KEY
                        Dataset key for input images.
  --target-key TARGET_KEY
                        Dataset key for target labels.
  --dataset-trust-remote-code
                        Allow huggingface dataset import to execute code
                        downloaded from the dataset's repo.

Model parameters:
  --model MODEL         Name of model to train (default: "resnet50")
  --pretrained          Start with pretrained version of specified network (if
                        available)
  --pretrained-path PRETRAINED_PATH
                        Load this checkpoint as if it were the pretrained
                        weights (with adaptation).
  --initial-checkpoint PATH
                        Load this checkpoint into model after initialization
                        (default: none)
  --resume PATH         Resume full model and optimizer state from checkpoint
                        (default: none)
  --no-resume-opt       prevent resume of optimizer state when resuming model
  --num-classes N       number of label classes (Model default if None)
  --gp POOL             Global pool type, one of (fast, avg, max, avgmax,
                        avgmaxc). Model default if None.
  --img-size N          Image size (default: None => model default)
  --in-chans N          Image input channels (default: None => 3)
  --input-size N N N    Input all image dimensions (d h w, e.g. --input-size 3
                        224 224), uses model default if empty
  --crop-pct N          Input image center crop percent (for validation only)
  --mean MEAN [MEAN ...]
                        Override mean pixel value of dataset
  --std STD [STD ...]   Override std deviation of dataset
  --interpolation NAME  Image resize interpolation type (overrides model)
  -b N, --batch-size N  Input batch size for training (default: 128)
  -vb N, --validation-batch-size N
                        Validation batch size override (default: None)
  --channels-last       Use channels_last memory layout
  --fuser FUSER         Select jit fuser. One of ('', 'te', 'old', 'nvfuser')
  --grad-accum-steps N  The number of steps to accumulate gradients (default:
                        1)
  --grad-checkpointing  Enable gradient checkpointing through model
                        blocks/stages
  --fast-norm           enable experimental fast-norm
  --model-kwargs [MODEL_KWARGS ...]
  --head-init-scale HEAD_INIT_SCALE
                        Head initialization scale
  --head-init-bias HEAD_INIT_BIAS
                        Head initialization bias value
  --torchcompile-mode TORCHCOMPILE_MODE
                        torch.compile mode (default: None).
  --torchscript         torch.jit.script the full model
  --torchcompile [TORCHCOMPILE]
                        Enable compilation w/ specified backend (default:
                        inductor).

Device parameters:
  --device DEVICE       Device (accelerator) to use.
  --amp                 use NVIDIA Apex AMP or Native AMP for mixed precision
                        training
  --amp-dtype AMP_DTYPE
                        lower precision AMP dtype (default: float16)
  --amp-impl AMP_IMPL   AMP impl to use, "native" or "apex" (default: native)
  --model-dtype MODEL_DTYPE
                        Model dtype override (non-AMP) (default: float32)
  --no-ddp-bb           Force broadcast buffers for native DDP to off.
  --synchronize-step    torch.cuda.synchronize() end of each step
  --local_rank LOCAL_RANK
  --device-modules DEVICE_MODULES [DEVICE_MODULES ...]
                        Python imports for device backend modules.

Optimizer parameters:
  --opt OPTIMIZER       Optimizer (default: "sgd")
  --opt-eps EPSILON     Optimizer Epsilon (default: None, use opt default)
  --opt-betas BETA [BETA ...]
                        Optimizer Betas (default: None, use opt default)
  --momentum M          Optimizer momentum (default: 0.9)
  --weight-decay WEIGHT_DECAY
                        weight decay (default: 2e-5)
  --clip-grad NORM      Clip gradient norm (default: None, no clipping)
  --clip-mode CLIP_MODE
                        Gradient clipping mode. One of ("norm", "value",
                        "agc")
  --layer-decay LAYER_DECAY
                        layer-wise learning rate decay (default: None)
  --opt-kwargs [OPT_KWARGS ...]

Learning rate schedule parameters:
  --sched SCHEDULER     LR scheduler (default: "cosine")
  --sched-on-updates    Apply LR scheduler step on update instead of epoch
                        end.
  --lr LR               learning rate, overrides lr-base if set (default:
                        None)
  --lr-base LR          base learning rate: lr = lr_base * global_batch_size /
                        base_size
  --lr-base-size DIV    base learning rate batch size (divisor, default: 256).
  --lr-base-scale SCALE
                        base learning rate vs batch_size scaling ("linear",
                        "sqrt", based on opt if empty)
  --lr-noise pct, pct [pct, pct ...]
                        learning rate noise on/off epoch percentages
  --lr-noise-pct PERCENT
                        learning rate noise limit percent (default: 0.67)
  --lr-noise-std STDDEV
                        learning rate noise std-dev (default: 1.0)
  --lr-cycle-mul MULT   learning rate cycle len multiplier (default: 1.0)
  --lr-cycle-decay MULT
                        amount to decay each learning rate cycle (default:
                        0.5)
  --lr-cycle-limit N    learning rate cycle limit, cycles enabled if > 1
  --lr-k-decay LR_K_DECAY
                        learning rate k-decay for cosine/poly (default: 1.0)
  --warmup-lr LR        warmup learning rate (default: 1e-5)
  --min-lr LR           lower lr bound for cyclic schedulers that hit 0
                        (default: 0)
  --epochs N            number of epochs to train (default: 300)
  --epoch-repeats N     epoch repeat multiplier (number of times to repeat
                        dataset epoch per train epoch).
  --start-epoch N       manual epoch number (useful on restarts)
  --decay-milestones MILESTONES [MILESTONES ...]
                        list of decay epoch indices for multistep lr. must be
                        increasing
  --decay-epochs N      epoch interval to decay LR
  --warmup-epochs N     epochs to warmup LR, if scheduler supports
  --warmup-prefix       Exclude warmup period from decay schedule.
  --cooldown-epochs N   epochs to cooldown LR at min_lr, after cyclic schedule
                        ends
  --patience-epochs N   patience epochs for Plateau LR scheduler (default: 10)
  --decay-rate RATE, --dr RATE
                        LR decay rate (default: 0.1)

Augmentation and regularization parameters:
  --no-aug              Disable all training augmentation, override other
                        train aug args
  --train-crop-mode TRAIN_CROP_MODE
                        Crop-mode in train
  --scale PCT [PCT ...]
                        Random resize scale (default: 0.08 1.0)
  --ratio RATIO [RATIO ...]
                        Random resize aspect ratio (default: 0.75 1.33)
  --hflip HFLIP         Horizontal flip training aug probability
  --vflip VFLIP         Vertical flip training aug probability
  --color-jitter PCT    Color jitter factor (default: 0.4)
  --color-jitter-prob PCT
                        Probability of applying any color jitter.
  --grayscale-prob PCT  Probability of applying random grayscale conversion.
  --gaussian-blur-prob PCT
                        Probability of applying gaussian blur.
  --aa NAME             Use AutoAugment policy. "v0" or "original". (default:
                        None)
  --aug-repeats AUG_REPEATS
                        Number of augmentation repetitions (distributed
                        training only) (default: 0)
  --aug-splits AUG_SPLITS
                        Number of augmentation splits (default: 0, valid: 0 or
                        >=2)
  --jsd-loss            Enable Jensen-Shannon Divergence + CE loss. Use with
                        `--aug-splits`.
  --bce-loss            Enable BCE loss w/ Mixup/CutMix use.
  --bce-sum             Sum over classes when using BCE loss.
  --bce-target-thresh BCE_TARGET_THRESH
                        Threshold for binarizing softened BCE targets
                        (default: None, disabled).
  --bce-pos-weight BCE_POS_WEIGHT
                        Positive weighting for BCE loss.
  --reprob PCT          Random erase prob (default: 0.)
  --remode REMODE       Random erase mode (default: "pixel")
  --recount RECOUNT     Random erase count (default: 1)
  --resplit             Do not random erase first (clean) augmentation split
  --mixup MIXUP         mixup alpha, mixup enabled if > 0. (default: 0.)
  --cutmix CUTMIX       cutmix alpha, cutmix enabled if > 0. (default: 0.)
  --cutmix-minmax CUTMIX_MINMAX [CUTMIX_MINMAX ...]
                        cutmix min/max ratio, overrides alpha and enables
                        cutmix if set (default: None)
  --mixup-prob MIXUP_PROB
                        Probability of performing mixup or cutmix when
                        either/both is enabled
  --mixup-switch-prob MIXUP_SWITCH_PROB
                        Probability of switching to cutmix when both mixup and
                        cutmix enabled
  --mixup-mode MIXUP_MODE
                        How to apply mixup/cutmix params. Per "batch", "pair",
                        or "elem"
  --mixup-off-epoch N   Turn off mixup after this epoch, disabled if 0
                        (default: 0)
  --smoothing SMOOTHING
                        Label smoothing (default: 0.1)
  --train-interpolation TRAIN_INTERPOLATION
                        Training interpolation (random, bilinear, bicubic;
                        default: "random")
  --drop PCT            Dropout rate (default: 0.)
  --drop-connect PCT    Drop connect rate, DEPRECATED, use drop-path (default:
                        None)
  --drop-path PCT       Drop path rate (default: None)
  --drop-block PCT      Drop block rate (default: None)

Batch norm parameters:
  Only works with gen_efficientnet based models currently.

  --bn-momentum BN_MOMENTUM
                        BatchNorm momentum override (if not None)
  --bn-eps BN_EPS       BatchNorm epsilon override (if not None)
  --sync-bn             Enable NVIDIA Apex or Torch synchronized BatchNorm.
  --dist-bn DIST_BN     Distribute BatchNorm stats between nodes after each
                        epoch ("broadcast", "reduce", or "")
  --split-bn            Enable separate BN layers per augmentation split.

Model exponential moving average parameters:
  --model-ema           Enable tracking moving average of model weights.
  --model-ema-force-cpu
                        Force ema to be tracked on CPU, rank=0 node only.
                        Disables EMA validation.
  --model-ema-decay MODEL_EMA_DECAY
                        Decay factor for model weights moving average
                        (default: 0.9998)
  --model-ema-warmup    Enable warmup for model EMA decay.

Miscellaneous parameters:
  --seed S              random seed (default: 42)
  --worker-seeding WORKER_SEEDING
                        worker seed mode (default: all)
  --log-interval N      how many batches to wait before logging training
                        status
  --recovery-interval N
                        how many batches to wait before writing recovery
                        checkpoint
  --checkpoint-hist N   number of checkpoints to keep (default: 10)
  -j N, --workers N     how many data loading worker processes to use
                        (default: 4)
  --save-images         save images of input batches every log interval for
                        debugging
  --pin-mem             Pin CPU memory in DataLoader for more efficient
                        (sometimes) transfer to GPU.
  --no-prefetcher       disable fast prefetcher
  --output PATH         path to output folder (default: none, current dir)
  --experiment NAME     name of train experiment, name of sub-folder for
                        output
  --eval-metric EVAL_METRIC
                        Best metric (default: "top1")
  --tta N               Test/inference time augmentation (oversampling)
                        factor. 0=None (default: 0)
  --use-multi-epochs-loader
                        use the multi-epochs-loader to save time at the
                        beginning of every epoch
  --log-wandb           log training and validation metrics to wandb
  --wandb-project WANDB_PROJECT
                        wandb project name
  --wandb-tags WANDB_TAGS [WANDB_TAGS ...]
                        wandb tags
  --wandb-resume-id ID  If resuming a run, the id of the run in wandb
  --naflex-loader       Use NaFlex loader (Requires NaFlex compatible model)
  --naflex-train-seq-lens NAFLEX_TRAIN_SEQ_LENS [NAFLEX_TRAIN_SEQ_LENS ...]
                        Sequence lengths to use for NaFlex loader
  --naflex-max-seq-len NAFLEX_MAX_SEQ_LEN
                        Fixed maximum sequence length for NaFlex loader
                        (validation)
  --naflex-patch-sizes NAFLEX_PATCH_SIZES [NAFLEX_PATCH_SIZES ...]
                        List of patch sizes for variable patch size training
                        (e.g., 8 12 16 24 32)
  --naflex-patch-size-probs NAFLEX_PATCH_SIZE_PROBS [NAFLEX_PATCH_SIZE_PROBS ...]
                        Probabilities for each patch size (must sum to 1.0,
                        uniform if not specified)
  --naflex-loss-scale NAFLEX_LOSS_SCALE
                        Scale loss (gradient) by batch_size ("none", "sqrt",
                        or "linear")
