# Hydra defaults list: compose the `inaturalist19` option of the `data` config
# group first, then this file (`_self_`), so values set in this file take
# precedence over same-named keys coming from the data config.
defaults:
  - data: inaturalist19
  - _self_

# Hydra runtime settings.
hydra:
  run:
    # Fixed output directory for single runs, replacing Hydra's default
    # time-stamped output directory.
    # NOTE(review): every run therefore writes into the same folder — confirm
    # that sharing/overwriting earlier Hydra outputs is intended.
    dir: ./artefacts/hydra_outputs

# Training settings.
training:
  num_epochs: 100
  use_cuda: true
  task: ImageClassification
  seed: 7
  validation_interval: 4
  # Hook settings.
  hooks:
    # Original author's note: actually true for distil.py; here this flag only
    # prevents stratified sampling.
    in_place: true
    single: false
    embedding_key: norm_pre
    fqn_prefix: blocks
    num_batches: 30
    batch_train: true
    optimal_steps: 600


# Model selection, weight/checkpoint locations and model kwargs.
model:
  name: deit_base_patch16_224
  pretrained: false
  path: ./training/models/deit_base_patch16_224.pth
  checkpoint_path: ./training/models/deit_gmp_pb50_zmxf938w.pt
  # checkpoint_path: ./artefacts/models/compressed_model_zl5zbxx6.pt
  kwargs:
    sparsity: 0.75
    p_budget: 0.5
    emb_fqn: norm_pre
    block_fqn_prefix: blocks
    layer_fqn_suffixes: [fc1, fc2]
    normalization:
      in_place: false
      # Written as `null` rather than the plain scalar `None`: YAML loads
      # `None` as the string "None", not as a null value, so "is None" style
      # checks would never match. This also agrees with the other unset values
      # in this file (wandb.run_id, benchmarking.batch_size).
      # NOTE(review): confirm no consumer compares against the string "None".
      fqn: null
      apply_before_idx: null
    block_group_indices:
      block_group_1: [0, 1, 2, 3]
      block_group_2: [4, 5, 6, 7]
      block_group_3: [8, 9, 10, 11]
    concat_axis: 1
    mask: false
    base_requires_grad: true


# Optimizer and learning-rate scheduler settings.
optimizer:
  name: AdamW
  label_smoothing: 0.11
  kwargs:
    lr: 1.0e-4
    weight_decay: 0.05
    betas: [0.9, 0.999]
    # Effectively 1e-8.
    eps: 9.999e-9
    amsgrad: false
  scheduler:
    name: CosineAnnealingLR
    kwargs:
      # Placeholder value: overridden in main.py.
      T_max: 0
      # NOTE(review): the original note on this key read "effectively this is
      # 1e-6", but the literal value is 1.0e-5 — confirm which one is intended.
      eta_min: 1.0e-5


# Sparsifier settings.
# NOTE(review): `sparsity` (0.75) and `final_sparsity` (0.5) differ — confirm
# which key each consumer actually reads.
sparsifier:
  name: rigl
  sparsity: 0.75
  global_pruning: true
  grown_weights_init: 0
  pruning_ratio: 0.05
  t_accel: 240
  initial_sparsity: 0.05
  final_sparsity: 0.5
  accelerated_sparsity: 0.3
  # Sparsifier schedule settings.
  scheduler:
    delta_t: 50
    t_end_coeff: 0.75


# Weights & Biases logging settings.
wandb:
  # Global switch for turning all wandb calls into no-ops; see
  # ./src/package/utils/wandb_utils.py
  ENABLE: true
  # Keep null for fresh runs; otherwise set to the run_id to load a checkpoint.
  run_id: null
  # Run name built from the model name, sparsifier name and learning rate.
  # NOTE(review): the dataset name and the "pb10" suffix are hard-coded, while
  # model.kwargs.p_budget in this file is 0.5 — confirm the label is current.
  name: ${model.name}-${sparsifier.name}-${optimizer.kwargs.lr}-inaturalist19-pb10
  project: SparseDecompositions2Share
  entity: cem1
  start_method: thread
  # If true, log images when relevant to the ML task.
  log_images: false


# Output locations. `logs` and `models` interpolate `paths.artefacts`; because
# that value already ends in "/", they resolve with a doubled slash
# (e.g. ./artefacts//logs/).
# NOTE(review): a doubled slash is normally harmless for POSIX path handling,
# but it will show up in any printed or logged paths.
paths:
  artefacts: ./artefacts/
  logs: ${paths.artefacts}/logs/
  models: ${paths.artefacts}/models/

# Benchmarking / profiling settings.
benchmarking:
  latency: false
  inference_profiling: false
  training_profiling: false
  param_budget: 0.75
  compressed: false
  compiled: false
  to_csr: false
  two_four: false
  # Interpolates benchmarking.batch_size (defined further down in this block).
  # NOTE(review): batch_size is null in this file, so unless it is overridden
  # this presumably renders as "memory_log_bs_None" — confirm.
  log_name: memory_log_bs_${benchmarking.batch_size}
  embed_dim: 768
  batch_size: null
  fullgraph: true
  output_dir: ./bench/logs/a100-fast-forward
  # Quoted so the colon can never be misread by a YAML parser.
  device: 'cuda:0'
  num_threads: null
  deepsparse: false
  min_run_time: 20
  fast_forward: false
  dtype: float16