# Scalar Affine + Low-Rank adapter (best-performing architecture)
# f(x) = scale * x + UV^T x + bias

# Run identifier; same string as projection.type in this file.
experiment_name: "scalar_affine_plus_low_rank"
# NOTE(review): presumably the global RNG seed — confirm which generators it seeds.
seed: 42

# Base model selection and load-time options.
# NOTE(review): presumably forwarded to the model loader — confirm in the consumer.
model:
  # Model identifier (org/name form).
  name: "meta-llama/Meta-Llama-3.1-8B-Instruct"
  # Numeric precision and device placement.
  dtype: "bfloat16"
  device_map: "auto"
  # Gradient checkpointing switch.
  enable_gradient_checkpointing: true

# Label source and batching settings.
data:
  labels_file: "data/goodfire_8b_l19_labels.json"  # See data/README.md for setup instructions
  batch_size: 80
  shuffle: true
  num_workers: 4
  # Same token that closes the user turn in soft_prompt.template.
  eos_token: "<|eot_id|>"
  # Written as canonical lowercase `true`, like every other boolean in this file.
  # NOTE(review): presumably trims whitespace around label text — confirm in the loader.
  strip_labels: true

# Adapter hyperparameters for f(x) = scale * x + UV^T x + bias.
projection:
  type: "scalar_affine_plus_low_rank"
  normalize_input: true
  # Low-rank term (UV^T).
  low_rank_rank: 64  # Rank-64 achieves best validation loss in paper
  # NOTE(review): presumably scales the initial U/V weights — confirm.
  low_rank_init_factor: 0.01
  # NOTE(review): presumably the starting value of `scale` — confirm.
  init_scale: 5.0

soft_prompt:
  # Chat-formatted prompt: one user turn followed by the start of the
  # assistant's reply.
  # The `|-` block scalar keeps interior newlines (including the blank lines)
  # and strips the trailing newline, so the value ends exactly at the open
  # double quote after `is `.
  # NOTE(review): "<|reserved_special_token_0|>" appears twice and is
  # presumably the slot where the adapter output is injected — confirm.
  template: |-
    <|begin_of_text|><|start_header_id|>user<|end_header_id|>

    What is the meaning of "<|reserved_special_token_0|>"?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

    The meaning of "<|reserved_special_token_0|>" is "

# Training-loop settings, grouped by concern.
training:
  # -- Optimizer --
  optimizer_type: "adamw"
  learning_rate: 0.01
  weight_decay: 0.01
  # -- Schedule and run length --
  scheduler_type: "cosine"
  warmup_steps: 10
  num_epochs: 2
  gradient_accumulation_steps: 1
  # -- Loss / gradient guards --
  # NOTE(review): max_loss looks like a ceiling on the per-step loss; confirm
  # whether larger values are clamped or the step is skipped.
  gradient_clip_norm: 0.5
  label_smoothing: 0.0
  max_loss: 100.0
  # -- Validation --
  val_fraction: 0.1
  validation_every_n_steps: 50
  # -- Checkpointing --
  checkpoint_dir: "./checkpoints"
  checkpoint_every_n_steps: 100
  save_final_checkpoint: true

# Experiment tracking and how often each kind of record is emitted.
logging:
  # Weights & Biases destination.
  use_wandb: true
  wandb_project: "selfie-adapter-training"
  # Cadences, expressed in training steps.
  log_every_n_steps: 1
  log_generations_every_n_steps: 100
  log_singular_values_every_n_steps: 50
  # NOTE(review): presumably the number of sample generations per logging
  # event — confirm.
  log_sample_generations: 5
