# Sweep definition: each run launches the script named in `program`.
program: e2e_sae/scripts/train_tlens_saes/run_train_tlens_saes.py
name: tinystories_1m_local
# Grid search over the value lists declared under `parameters`.
method: grid
# Objective associated with the sweep; lower values are better.
# NOTE(review): assumes the training script logs a metric named
# `val_loss` - confirm against the script.
metric:
  name: val_loss
  goal: minimize
# Hyperparameter grid: 1 (n_samples) x 1 (lr) x 6 (coeff) x 4 (ratio)
# = 24 combinations.
# NOTE(review): the nested `parameters` keys presumably mirror the nested
# layout of the training config they override - confirm against the script.
parameters:
  n_samples:
    # Plain digits on purpose: `400_000` is an integer only under YAML 1.1
    # rules; YAML 1.2 core-schema loaders read it as a string.
    values: [400000]
  lr:
    # Written with a decimal point on purpose: YAML 1.1 loaders such as
    # PyYAML resolve a bare `1e-3` as the string "1e-3", not a float.
    values: [0.001]
  loss:
    parameters:
      sparsity:
        parameters:
          # Sparsity-loss coefficient values to sweep.
          coeff:
            values: [0.001, 0.005, 0.008, 0.01, 0.02, 0.05]
  saes:
    parameters:
      # Single hook position, so this axis does not multiply the grid.
      sae_positions:
        values: [blocks.4.hook_resid_pre]
      dict_size_to_input_ratio:
        values: [5, 20, 60, 100]

# Launch line for every run: the env / interpreter / program macros,
# followed by the base config file as the only positional argument.
# No `${args}` entry is listed, so swept values are not appended as CLI
# flags; presumably the script reads them from the run's sweep config -
# confirm.
command:
  - ${env}
  - ${interpreter}
  - ${program}
  - e2e_sae/scripts/train_tlens_saes/tinystories_1M_local.yaml