# Settings for Hydra itself (the reserved top-level `hydra` key).
hydra:
  run:
    # Output/working directory for a single (non-multirun) Hydra run.
    # NOTE(review): this is a fixed path with no timestamp or job id, and it
    # hard-codes ./artefacts instead of referencing paths.artefacts; repeated
    # runs presumably write into the same directory - confirm this is intended.
    dir: ./artefacts/hydra_outputs

# Training-run settings.
# NOTE(review): `task` is ImageClassification while model.name is
# huggyllama/llama-7b and data.name is redpajama - confirm which is intended.
training: 
  num_epochs: 40
  use_cuda: True
  task: ImageClassification
  seed: 7
  # presumably counted in epochs - TODO confirm the unit against the training loop
  validation_interval: 2
  # Settings grouped under `hooks`; the code that consumes them is not visible
  # from this file.
  hooks:
    in_place: True
    single: False
    # NOTE(review): differs from model.kwargs.emb_fqn (patch_embed) - confirm
    # that both values are intended.
    embedding_key: embed_tokens
    # Same value as model.kwargs.block_fqn_prefix.
    fqn_prefix: layers
    num_batches: 20
    batch_train: True
    optimal_steps: 400

# Model selection and model-construction arguments.
model:
  # Contains a "/" (looks like a Hugging Face Hub repo id - confirm). This
  # value is also interpolated into wandb.name.
  name: huggyllama/llama-7b
  # Slash-free identifier; presumably used when naming saved artefacts - confirm.
  save_name: Llama7B
  pretrained: True
  path: ./
  checkpoint_path: ./
  kwargs:
    model_precision: 32
    mlp_precision: 32
    # NOTE(review): distinct from benchmarking.param_budget (0.90) - confirm the
    # two are meant to differ.
    p_budget: 0.05
    # Same value as sparsifier.sparsity and sparsifier.final_sparsity (0.75);
    # presumably these need to stay in sync - confirm.
    sparsity: 0.75
    structured_sparsity: False
    # NOTE(review): differs from training.hooks.embedding_key (embed_tokens) -
    # confirm this name matches a module of the model selected above.
    emb_fqn: patch_embed
    # Same value as training.hooks.fqn_prefix.
    block_fqn_prefix: layers
    layer_fqn_suffixes: [up_proj, down_proj, gate_proj]
    normalization:
      in_place: False
      # NOTE(review): a plain `None` is loaded by YAML as the string "None",
      # not as a null value (wandb.run_id in this file uses `null`). Confirm the
      # consumer really expects the string before changing either entry.
      fqn: None
      apply_before_idx: None
    # Six groups that together cover indices 0-31 contiguously, with no gaps
    # and no index listed twice.
    block_group_indices:
      block_group_1: [0, 1, 2, 3]
      block_group_2: [4, 5, 6, 7, 8, 9]
      block_group_3: [10, 11, 12, 13, 14, 15]
      block_group_4: [16, 17, 18, 19, 20, 21]
      block_group_5: [22, 23, 24, 25, 26, 27]
      block_group_6: [28, 29, 30, 31]
    concat_axis: 1
    mask: True
    base_requires_grad: True

# Optimizer and learning-rate scheduler settings.
optimizer:
  name: AdamW
  # NOTE(review): sits next to `kwargs` rather than inside it, so presumably it
  # is consumed separately from the optimizer's constructor arguments - confirm.
  label_smoothing: 0.11
  # presumably forwarded to the optimizer named above - confirm.
  # Exponent-form floats here are written with a decimal point and a signed
  # exponent (2.0e-6 rather than 2e-6); that form is resolved as a float by
  # YAML 1.1 loaders, which a bare `2e-6` is not guaranteed to be.
  kwargs:
    # Also interpolated into wandb.name.
    lr: 2.0e-6
    weight_decay: 0.1
    betas: [0.9, 0.999]
    eps: 1.0e-8
    amsgrad: False
  scheduler:
    name: CosineAnnealingLR
    kwargs: 
      T_max: 800 # being overridden in main.py
      # NOTE(review): the value written here is 2.0e-7; the trailing comment's
      # 1e-6 figure is not explained anywhere in this file - confirm or correct.
      eta_min: 2.0e-7  # effectively this is 1e-6

# Sparsifier (pruning) settings.
sparsifier:
  name: gradual
  # Same value as final_sparsity below and as model.kwargs.sparsity (0.75).
  sparsity: 0.75
  global_pruning: True
  grown_weights_init: 0
  pruning_ratio: 0.1
  # presumably the step at which accelerated_sparsity is reached - TODO confirm
  t_accel: 200
  # Three sparsity levels; ordered by value they are
  # initial (0.05) < accelerated (0.50) < final (0.75).
  initial_sparsity: 0.05
  final_sparsity: 0.75
  accelerated_sparsity: 0.50
  scheduler:
    # presumably the number of steps between sparsity updates - TODO confirm
    delta_t: 50
    # presumably the fraction of training after which updates stop - TODO confirm
    t_end_coeff: 0.75

# TODO - check clip grad norm
# Dataset and data-loading settings.
data:
  # NOTE(review): paths.data.directory points at ./data/ImageNet/ while this
  # entry names redpajama - confirm which dataset is actually used.
  name: redpajama
  # NOTE(review): the unit is not visible from this file (samples vs. tokens) -
  # confirm against the data loader.
  batch_size: 8192
  num_workers: 2

# Weights & Biases (wandb) logging settings.
wandb:
  ENABLE: True # Global switch for turning all wandb calls into no-ops, see ./src/package/utils/wandb_utils.py
  run_id: null # Keep null for fresh runs, otherwise, set to run_id to load ckpt
  # Run name built by interpolating model.name and optimizer.kwargs.lr; the
  # "gmp" and "pb355" parts are literal text.
  # NOTE(review): model.name is huggyllama/llama-7b, so the resolved name
  # contains a "/"; model.save_name (Llama7B) is a slash-free alternative if
  # that turns out to be a problem - confirm.
  name: ${model.name}-gmp-${optimizer.kwargs.lr}-pb355
  project: SparseDecompositions2Share
  entity: cem1
  start_method: thread
  log_images: False # If True, log images when relevant to ML task

# Filesystem locations for input data and generated artefacts.
paths:
  data: 
    # The trailing comment keeps an alternative value (an environment lookup of
    # SLURM_TMPDIR); presumably for cluster runs - confirm.
    # NOTE(review): this points at ImageNet while data.name is redpajama.
    directory: ./data/ImageNet/ # ${oc.env:SLURM_TMPDIR}
    train: train
    val: val
  artefacts: ./artefacts/
  # The two entries below interpolate paths.artefacts. Because that value
  # already ends in "/", they resolve with a doubled slash (./artefacts//logs/
  # and ./artefacts//models/). POSIX path resolution treats consecutive slashes
  # as one, and the explicit "/" keeps the result correct if paths.artefacts is
  # overridden without a trailing slash.
  logs: ${paths.artefacts}/logs/
  models: ${paths.artefacts}/models/
  
# Benchmarking and profiling switches.
benchmarking:
  latency: False
  inference_profiling: True
  training_profiling: False
  # NOTE(review): distinct from model.kwargs.p_budget (0.05) - confirm the two
  # are meant to differ.
  param_budget: 0.90
  compressed: True
  compiled: False
