# Hyperparameter sweep configuration (Weights & Biases sweep format).
#
# Exhaustive grid over learning rate, sparsity coefficient and dictionary size
# ratio: 1 x 3 x 3 x 1 x 2 = 18 parameter combinations in total.
program: e2e_sae/scripts/train_tlens_saes/run_train_tlens_saes.py
name: logits_kl
method: grid
# Metric recorded for the sweep; lower is better.
metric:
  name: val_loss
  goal: minimize
parameters:
  # Written without a digit separator: `450_000` is an int only under
  # YAML 1.1 and is read as a string by YAML 1.2 core-schema loaders.
  n_samples:
    values: [450000]
  # Written in plain decimal form: exponent literals without a decimal point
  # (e.g. `1e-2`) are loaded as strings, not floats, by YAML 1.1 resolvers
  # such as PyYAML.
  lr:
    values: [0.01, 0.005, 0.001]
  # Nested `parameters` keys address nested fields: this sweeps
  # loss.sparsity.coeff.
  loss:
    parameters:
      sparsity:
        parameters:
          coeff:
            values: [50, 30, 20]
  saes:
    parameters:
      # Single value, so this axis does not multiply the grid size.
      sae_positions:
        values: [blocks.1.hook_resid_pre]
      dict_size_to_input_ratio:
        values: [5.0, 10.0]

# Launch command for each run. `${args}` is intentionally not listed, so the
# swept values are not passed as CLI flags; presumably the program reads them
# from the sweep run's config on top of the base YAML below (TODO confirm).
command:
- ${env}
- ${interpreter}
- ${program}
- e2e_sae/scripts/train_tlens_saes/tinystories_1M_e2e.yaml