# Hydra defaults list: merge ../global_params into this config's own package
# (@_here_), then this file itself (_self_). Because _self_ is listed last,
# any key set below overrides a same-named key coming from global_params.
defaults:
  - ../global_params@_here_
  - _self_

# PATH PARAMS (SET THESE!)
exp_name: null # experiment name; used to assign the run to a wandb project and to build the run's save path
run_name: null # run name; used to name the run in wandb and to build the run's save path

# ARCHITECTURE PARAMS
# default model per dataset: resnet18 for cifar10, resnet34 for cifar100
model: resnet18 # one of [vgg16, resnet18, resnet18_gn, resnet_tiny, resnet_tiny_gn, resnet34, resnet34_gn]
dataset: cifar10 # one of [cifar10, cifar100, mnist, tiny_imagenet, cnn]; NOTE(review): "cnn" reads like a model name rather than a dataset — confirm it is a valid dataset key

# TRAINING PARAMS
epochs: 200
extra_per_epoch: 0 # used in base runs to match the dataset size of upsampled runs; if null, the normal dataset size is used
repeat_random_extra: 1 # NOTE(review): undocumented; presumably how many times the random extra examples are repeated — confirm against the data loader
eval_model_every_k_epochs: 1 # evaluate and save the model every k epochs; 0 means only at the end
model_width: 64 # number of conv filters in the first layer (resnet models)
batch_size: 128
lr_max: 0.1
lr_decay: 0.1 # multiplicative LR factor applied at each point in lr_epoch_interval; 1 means no decay, 0.1 cuts the LR by one order of magnitude each time
lr_epoch_interval: [100, 150] # int or list; with a list, the LR is multiplied by lr_decay at the end of each listed epoch
valset: true # if false, the data loader returns None for the valset and puts all data in the trainset (used when generating synthetic data / offline data augmentation)
str_augm: none # one of [ta, ra, aa, none]; note that `none` here is the string 'none', not YAML null
augm: true # basic augmentation (crop/flip)
l2_reg: 0.0005
droprate: 0.0
loss: ce # currently only cross-entropy is supported
shuffle: true
momentum: 0.9
noise_percent: 0 # label noise percentage

# UPWEIGHT LOSS PARAMS
loss_upweight_idx_path: null # NOTE(review): presumably a path to the indices of examples whose loss is upweighted — confirm
loss_upweight_weight: 2 # NOTE(review): presumably the loss multiplier applied to those examples — confirm
error_barrier_epoch: null # epoch at which to shuffle the loader, for computing the error barrier w.r.t. another run
error_barrier_reroll: 400 # number of times to reroll (presumably the loader shuffle — confirm)

# SAM PARAMS
sam_rho: 0.0 # SAM rho; 0 disables SAM (0.1 is the base-SAM setting)
sam_no_grad_norm: false # NOTE(review): undocumented; presumably skips gradient normalization in the SAM step — confirm

# STAT COLLECTION PARAMS
get_fs_stats: false # collect forget score statistics
get_online_fs_stats: true # NOTE(review): undocumented; presumably collects forget score statistics online during training — confirm
get_sharpness_stats: false # collect sharpness statistics
get_grad_stats: false # collect gradient norm statistics

# UPSAMPLE (US) PARAMS
us: false # NOTE(review): presumably the master switch for upsampling — confirm
us_type: real # [real, syn]
us_idx_path: null # if null: uses the whole syn dataset when us_type is syn, and does nothing when us_type is real
us_syn_dataset: null # synthetic dataset to use; covers the dif, ta, ra, aa dataset types
us_syn_only: false
us_syn_100_100: false

# AUTO-UPSAMPLE (AUS) PARAMS
aus: false # NOTE(review): presumably the master switch for auto-upsampling — confirm
upweight_loss_instead: false # during retraining, upweight the loss of the selected examples instead of adding them to the dataset
aus_epochs: 10 # number of epochs to train before auto-upsampling
aus_avg_epochs: false
aus_weight: 2
aus_clusters: 10
aus_method: clustering # [clustering, threshold, quantile]
aus_score: grad_norm # [grad_norm, cluster_size, forget_score, conf_score, margin, error, acc]
aus_score_threshold: 0 # used when aus_method = threshold
aus_score_quantile_range: [0.5, 1] # quantile range used when aus_method = quantile
aus_quantile_per_class: true
aus_cluster_by: activations # [activations, logits]
aus_cluster_range: [3, 5]
aus_train_after: true # set to false to only find the upsample indices without retraining
rewind_to_start: true # after AUS, reset the model weights to their epoch-0 values
epochs_after_aus: full # [full, remaining]
aus_shuffle_after: false

# META PARAMS
seed: 0
gpu: 0 # NOTE(review): presumably the CUDA device index — confirm
debug: false # if true: logging, wandb, and checkpointing are disabled

model_ckpt_path: null # checkpoint to start training from (null: train from scratch)
