---
# Sweep header. The key layout matches a Weights & Biases sweep definition
# (NOTE(review): confirm this file is consumed by `wandb sweep`).
name: gpt2_local
# Script launched for every run of the sweep.
program: e2e_sae/scripts/train_tlens_saes/run_train_tlens_saes.py
# Grid search: every combination of the listed parameter values is tried.
method: grid
# Metric the sweep is scored on; lower is better.
metric:
  goal: minimize
  name: val_loss
# Values swept over. Each leaf lists its candidates under `values`; nested
# groups are declared with an inner `parameters` mapping.
parameters:
  # Two seeds, so every other setting is run twice.
  seed:
    values: [0, 1]
  # Written without a digit separator: `400_000` is an integer only under
  # YAML 1.1 rules and is read as a string by YAML 1.2 parsers.
  n_samples:
    values: [400000]
  # Written with a decimal point: YAML 1.1 resolvers such as PyYAML require
  # one to recognise a float, so a bare `5e-4` is loaded as the string "5e-4".
  lr:
    values: [5.0e-4]
  # Loss-term settings. Within this group only `sparsity.coeff` lists more
  # than one candidate value.
  loss:
    parameters:
      sparsity:
        parameters:
          coeff:
            values:
              - 1
              - 3
              - 6
              - 10
              - 15
              - 20
              - 30
      # The terms set to null below are presumably switched off by the
      # training script — TODO confirm against its config schema.
      in_to_orig:
        values:
          - null
      out_to_orig:
        values:
          - null
      out_to_in:
        parameters:
          coeff:
            values:
              - 1.0
      logits_kl:
        values:
          - null
  # SAE settings: a single hook-point name and a single dictionary-size ratio.
  saes:
    parameters:
      # Quoted so the dotted name is unambiguously a string.
      sae_positions:
        values: ['blocks.10.hook_resid_pre']
      dict_size_to_input_ratio:
        values: [60.0]

  # Training dataset settings; every field has exactly one candidate value,
  # so none of them widens the grid.
  train_data:
    parameters:
      dataset_name:
        values:
          - ANONYMIZED/Skylion007-openwebtext-tokenizer-gpt2
      is_tokenized:
        values:
          - true
      tokenizer_name:
        values:
          - gpt2
      streaming:
        values:
          - true
      split:
        values:
          - train
      n_ctx:
        values:
          - 1024
  # Evaluation dataset settings; every field has exactly one candidate value.
  eval_data:
    parameters:
      dataset_name:
        values:
          - ANONYMIZED/Skylion007-openwebtext-tokenizer-gpt2
      is_tokenized:
        values:
          - true
      tokenizer_name:
        values:
          - gpt2
      streaming:
        values:
          - true
      # NOTE(review): same split name as `train_data`; confirm the training
      # script draws evaluation samples separately from training samples.
      split:
        values:
          - train
      n_ctx:
        values:
          - 1024
# Command line used to start each run. `${env}`, `${interpreter}` and
# `${program}` are placeholders filled in by the sweep agent. No `${args}`
# entry is present, so swept values are not appended as CLI flags here —
# presumably the script reads them from the run config; TODO confirm.
command:
  - ${env}
  - ${interpreter}
  - ${program}
  # Extra positional argument passed to the program: path to a YAML file.
  - e2e_sae/scripts/train_tlens_saes/gpt2_local.yaml