# Hydra defaults list: selects the config-group options composed with this file.
defaults:
  - _self_ # We default to values in this file if group settings below are null
  - dataset: mnist
  - model: mnist
  # The entries below interpolate the selections above: the rigl and training
  # options follow the chosen model option, and the compute option follows the
  # chosen dataset option.
  - rigl: ${model}
  - training: ${model}
  - compute: ${dataset}

# Experiment identity: run name, name suffix, resume flag, and run id.
experiment:
  comment: "_dynamic_ablation-${rigl.dynamic_ablation}" # Appended to the end of `name`; start with leading underscore
  # Run name built from the model name, dataset name, and selected rigl settings.
  # Each trailing backslash inside the double-quoted scalar is a line
  # continuation, so the value is a single string with no embedded newlines
  # or indentation.
  name:
    "${model.name}_${dataset.name}_dense_alloc-${rigl.dense_allocation}_const_fan_in-${rigl.const_fan_in}\
    _min_sal-${rigl.min_salient_weights_per_neuron}\
    _conv_proj_dense-${rigl.keep_first_layer_dense}_mha-dense-${rigl.ignore_mha_layers}\
    ${experiment.comment}"
  resume_from_checkpoint: False
  run_id: null

# Filesystem locations; every path below is derived from `base`.
paths:
  # Root directory, read from the BASE_PATH environment variable (no fallback value is given here).
  base: ${oc.env:BASE_PATH}
  data_folder: ${paths.base}/data
  artifacts: ${paths.base}/artifacts
  logs: ${paths.base}/logs
  # Checkpoints live under the artifacts directory.
  checkpoints: ${paths.artifacts}/checkpoints
  graphs: ${paths.base}/graphs

# RigL sparse-training settings (see the per-key comments below).
rigl:
  # Fraction of dense parameters allowed. If null, pruning will not
  # be used. Must be on the interval (0, 1].
  dense_allocation: 0.1
  # Number of steps between rigl pruning / regrow iterations
  delta: 800 # 4096 * 100 / 512 = 800 || 100 for cifar || x5 -> No change?
  # Number of gradients to accumulate before scoring for rigl
  grad_accumulation_n: 8 # 4096 / 512 without simulated batch size, 1 otherwise
  # Alpha parameter of the cosine decay that determines how many connections to prune/regrow
  alpha: 0.3 # 0.3 for vanilla rigl, 0.5 for ITOP -> Initial fraction of weights to update every pruner step
  # If 1, use random sparsity topology and remain static
  static_topo: 0
  # If True, use const_fan_in scheduler.
  const_fan_in: True
  # Define layer-wise sparsity distribution. Options include `uniform` & `erk`
  sparsity_distribution: erk
  # Power scale for ERK distribution
  erk_power_scale: 1.0
  # Use T_end from vanilla rigl. If False, explore new topologies until the end of training
  use_t_end: True
  # Instructs scheduler to ignore linear layers and leave them dense if True.
  ignore_linear_layers: False
  # Static ablation: if True, ablate filters at init with a static mask.
  static_ablation: False
  # Static filter ablation threshold criterion -> Not used for dynamic ablation
  filter_ablation_threshold: null # 1%
  # Dynamic ablation will ablate filters dynamically during training based on gradient information. Cannot be True if static_ablation is True
  dynamic_ablation: True
  # If False, min_salient_weights_per_neuron as a percentage will ablate based on percentage of dense fan-in. If True, use sparse fan-in as the percentage denominator.
  use_sparse_const_fan_in_for_ablation: True
  # Minimum number of salient weights per neuron. If below this value, the neuron will be dynamically ablated.
  min_salient_weights_per_neuron: 0.3
  # If True, update weight initialization to account for sparse const fan-in.
  use_sparse_initialization: True
  # Chooses between sparse init methods -> options include "kaiming_uniform", "kaiming_normal", "sparse_torch", and "grad_flow_init"
  init_method_str: grad_flow_init
  # Keep first layer dense
  keep_first_layer_dense: False
  # Value to initialize regrown weights to. Used for debugging. Use 0 for vanilla rigl methodology.
  initialize_grown_weights: 0.0
  # Ignore multi-head attention layers
  ignore_mha_layers: False

# Training loop, optimizer, and learning-rate scheduler settings.
training:
  # Quickly check a single pass
  dry_run: False
  # Batch size for training
  batch_size: 512 # 128 for cifar10, multiples of 256 for imagenet w/ scaled lr -> 512 with 2 v100s or 4 3090s -> 1024 with 4 v100s
  # Batch size to simulate by taking multiple steps per optimizer update (accumulate grads)
  simulated_batch_size: null
  # Batch size for testing
  test_batch_size: 1000
  # Number of epochs to train
  epochs: 103 # For imagenet, training will finish at ~ epoch 102.31, use 250 for cifar10 || imagenet x5 -> 515
  # Random seed
  seed: 42
  # Log interval (number of mini-batch steps)
  log_interval: 100
  # Whether to save the current model
  save_model: True
  # Max number of steps (will override epochs) -> Based on number of optimizer steps (i.e., dataset len * epochs * simulated batch size)
  # NOTE(review): the formula above multiplies by the simulated batch size, while
  # the inline example below divides by the batch size -- confirm which is intended.
  max_steps: 256000 # 4096 * 32000 / 512 = 256000 from rigl paper || x5 -> 4096*32000/1024 = 640000

  ## Optimization
  optimizer: sgd # "sgd", "adamw"
  # L2 regularization for optimizer
  weight_decay: 0.0001 # 5.0e-4 -> cifar10   0.0001 -> imagenet
  # Momentum coefficient for SGD optimizer
  momentum: 0.9 # 0.9 -> rigl value
  # Label smoothing for cross entropy
  label_smoothing: 0.1 # 0.1 for imagenet, 0.0 for cifar
  # Betas for adam
  betas: [0.9, 0.999]
  # Gradient clipping. If null, no clipping is performed.
  clip_grad_norm: null
  # RMSprop smoothing alpha
  alpha: 0.9

  ## Scheduler
  scheduler: step_lr_with_warm_up # step_lr -> cifar # cosine_annealing_with_warm_up, step_lr_with_warm_up -> imagenet
  # Learning rate after warmup
  lr: 0.2 # 0.1 * batch size / 256 per original paper -> Use 1.6 for 4096 simulated batch size. 0.1 for ITOP
  # Learning rate for first epoch of warm up, only applies to schedulers with linear warm up
  init_lr: 1.0e-6
  # Number of epochs to warm up for linear warm ups
  # NOTE(review): the key is named `warm_up_steps` but the comments describe
  # epochs -- confirm the unit against the scheduler implementation.
  warm_up_steps: 5 # 5 for imagenet | 0 for cifar10 | 25 for x5 imagenet
  # Learning rate step gamma
  gamma: 0.1 # 0.1 for imagenet, 0.2 for cifar
  # Step size to use in learning rate scheduler if StepLR is used. List[int] || int
  step_size: [30, 70, 90] # Imagenet -> [30,70,90] For cifar 10 -> every 30,000 mini-batch steps ~= 77 epochs | [150, 350, 450] for x5 | [60, 140, 180] for x2
  # Use mixed precision
  use_amp: False

# Device, data-loading, and distributed-training settings.
compute:
  # Disables CUDA training
  no_cuda: False
  cuda_kwargs:
    # Read from the NUM_WORKERS environment variable (no fallback value is given here).
    num_workers: ${ oc.decode:${oc.env:NUM_WORKERS} } # decode casts to int
    pin_memory: True
  # If True, use data-parallelization model for distributed training
  distributed: True
  # Number of CUDA devices to use for distributed training
  world_size: 4
  # Backend to use for distributed training, options: ["nccl", "gloo"]
  dist_backend: "nccl"
  # Use tensor 32 bit floats
  use_tf32: False

# Weights & Biases (wandb) logging options.
wandb:
  log_to_wandb: False
  project: sparsimony
  # Alternative project (commented out):
  # project: condensed-rigl
  entity: mklasby
  # Alternative entity (commented out):
  # entity: condensed-sparsity
  start_method: thread
  log_images: False
  watch_model_grad_and_weights: False
  log_filter_stats: True

# Hydra settings: the run directory is set to the logs path defined under `paths`.
hydra:
  run:
    dir: ${paths.logs}

# Values for the `model` node defined in this file.
model:
  # Unset (null) here.
  # NOTE(review): presumably a list of module names to exclude from ablation --
  # confirm against the consumer.
  no_ablation_module_names: null

# Logging settings.
logging:
  # Log level name.
  level: info
