# Hydra runtime settings.
hydra:
  run:
    # Fixed output directory: the path contains no interpolation, so every
    # run resolves to this same location.
    dir: ./artefacts/hydra_outputs

# Training-run settings.
training:
  num_epochs: 40
  use_cuda: true
  # NOTE(review): task is ImageClassification while model.name is a Gemma
  # checkpoint and data.name is redpajama -- confirm this value is still read.
  task: ImageClassification
  seed: 7
  validation_interval: 2
  # Hook settings; the meaning of each key is defined by the consuming code.
  hooks:
    in_place: true
    single: false
    embedding_key: embed_tokens
    fqn_prefix: layers
    num_batches: 20
    batch_train: true
    optimal_steps: 800

# Model identity, checkpoint locations, and model kwargs.
model:
  name: google/gemma-2-2b-it
  save_name: gemma2_2b
  pretrained: true
  path: ./ # ${oc.env:SLURM_TMPDIR}/models/deit_base_patch16_224.pth
  checkpoint_path: ./
  kwargs:
    model_precision: 16
    mlp_precision: 16
    p_budget: 0.5
    # Same value as sparsifier.sparsity and sparsifier.final_sparsity.
    sparsity: 0.75
    structured_sparsity: false
    # NOTE(review): training.hooks.embedding_key is embed_tokens; confirm that
    # patch_embed is the intended embedding name for this model.
    emb_fqn: patch_embed
    block_fqn_prefix: layers
    layer_fqn_suffixes: [up_proj, down_proj, gate_proj]
    normalization:
      in_place: false
      # Written as `null`, not bare `None`: YAML loads an unquoted None as the
      # string "None", which is truthy and is not Python's None.
      fqn: null
      apply_before_idx: null
    # Six groups that together cover indices 0-25 with no gaps or overlaps.
    block_group_indices:
      block_group_1: [0, 1, 2, 3]
      block_group_2: [4, 5, 6, 7]
      block_group_3: [8, 9, 10, 11, 12]
      block_group_4: [13, 14, 15, 16, 17]
      block_group_5: [18, 19, 20, 21]
      block_group_6: [22, 23, 24, 25]
    concat_axis: 1
    mask: true
    base_requires_grad: true

# Optimizer and learning-rate scheduler settings.
optimizer:
  name: AdamW
  label_smoothing: 0.11
  kwargs:
    # NOTE(review): the previous comment said "reduced from 3.125e-5 for
    # better numerical stability", but 4.0e-4 is larger than 3.125e-5 --
    # confirm which value is intended.
    lr: 4.0e-4
    weight_decay: 0.01
    betas: [0.9, 0.999]
    # Increased for better numerical stability.
    eps: 1.0e-6
    amsgrad: false
  scheduler:
    name: CosineAnnealingLR
    kwargs:
      # Overridden in main.py.
      T_max: 800
      # NOTE(review): the previous comment said "effectively this is 1e-6";
      # confirm how the effective minimum LR is derived from this value.
      eta_min: 4.0e-5

# Sparsifier settings (gradual schedule).
sparsifier:
  name: gradual
  # Same value as final_sparsity below and model.kwargs.sparsity.
  sparsity: 0.75
  global_pruning: true
  grown_weights_init: 0
  pruning_ratio: 0.05
  t_accel: 200
  # Sparsity levels listed here: 0.1 (initial), 0.5 (accelerated),
  # 0.75 (final).
  initial_sparsity: 0.1
  final_sparsity: 0.75
  accelerated_sparsity: 0.5
  scheduler:
    delta_t: 50
    t_end_coeff: 0.75

# TODO - check clip grad norm
# Dataset name plus batch_size / num_workers.
data:
  name: redpajama
  batch_size: 8192
  num_workers: 2

# Weights & Biases logging settings.
wandb:
  # Global switch for turning all wandb calls into no-ops,
  # see ./src/package/utils/wandb_utils.py
  ENABLE: true
  # Keep null for fresh runs; otherwise set to the run_id to load a ckpt.
  run_id: null
  # Run name built from model.name, optimizer.kwargs.lr and
  # model.kwargs.p_budget via interpolation.
  name: ${model.name}-${optimizer.kwargs.lr}-pb${model.kwargs.p_budget}-global-sp
  project: SparseDecompositions2Share
  entity: cem1
  start_method: thread
  # If true, log images when relevant to the ML task.
  log_images: false

# Filesystem locations for data and generated artefacts.
paths:
  data:
    # NOTE(review): directory points at ImageNet while data.name is
    # redpajama -- confirm this path is still used.
    directory: ./data/ImageNet/ # ${oc.env:SLURM_TMPDIR}
    train: train
    val: val
  artefacts: ./artefacts/
  # artefacts already ends in "/", so with the value above these resolve with
  # a doubled slash (./artefacts//logs/ and ./artefacts//models/).
  logs: ${paths.artefacts}/logs/
  models: ${paths.artefacts}/models/
  
# Benchmarking switches; only inference_profiling and compressed are enabled.
benchmarking:
  latency: false
  inference_profiling: true
  training_profiling: false
  # Differs from model.kwargs.p_budget (0.5).
  param_budget: 0.90
  compressed: true
  compiled: false
