---
# W&B-style sweep definition.
# `program` is the training entry point; the `command` list in this file refers
# to it through the ${program} placeholder.
program: e2e_sae/scripts/train_tlens_saes/run_train_tlens_saes.py
name: gpt2-e2e
# Grid search: every combination of the candidate lists under `parameters`.
method: grid
# Objective associated with the sweep; lower val_loss is better.
metric:
  name: val_loss
  goal: minimize
parameters:
  # Fixed (single-candidate) run settings.
  seed:
    values: [0]
  # Written without a digit separator: `400_000` is only an integer under
  # YAML 1.1 rules; YAML 1.2 core-schema parsers read it as a string.
  n_samples:
    values: [400000]
  # The mantissa needs a decimal point: YAML 1.1 resolvers such as PyYAML read
  # a bare `5e-4` as the string "5e-4" rather than the float 0.0005.
  lr:
    values: [5.0e-4]
  # Loss-term settings (nested sweep parameters).
  # Only sparsity.coeff (3 candidates) and in_to_orig.total_coeff
  # (2 candidates) vary in this block, i.e. 3 x 2 = 6 combinations.
  # Swept axes are written as block lists so that adding or removing a
  # candidate is a one-line diff; fixed values stay as compact flow lists.
  loss:
    parameters:
      sparsity:
        parameters:
          # Swept axis.
          coeff:
            values:
              - 0.8
              - 2.5
              - 5
      in_to_orig:
        parameters:
          # One candidate, and that candidate is itself a one-element list.
          hook_positions:
            values:
              - [blocks.11.hook_resid_pre]
          # Swept axis.
          total_coeff:
            values:
              - 0.05
              - 0.1
      # Single candidate: null.
      out_to_orig:
        values: [null]
      out_to_in:
        parameters:
          coeff:
            values: [0.0]
      logits_kl:
        parameters:
          coeff:
            values: [0.5]
  # SAE settings; both entries have a single candidate.
  saes:
    parameters:
      # The candidate is a plain string, not a list.
      sae_positions:
        values: [blocks.10.hook_resid_pre]
      dict_size_to_input_ratio:
        values: [60.0]

  # Training data settings; every entry has a single candidate.
  train_data:
    parameters:
      # Which data to read.
      dataset_name:
        values: [ANONYMIZED/Skylion007-openwebtext-tokenizer-gpt2]
      split:
        values: [train]
      streaming:
        values: [true]
      # Tokenizer-related settings and n_ctx.
      is_tokenized:
        values: [true]
      tokenizer_name:
        values: [gpt2]
      n_ctx:
        values: [1024]
  # Evaluation data settings; every entry has a single candidate.
  # NOTE(review): dataset_name and split are identical to train_data (both
  # read the `train` split) - confirm the consumer keeps evaluation samples
  # separate from the training samples.
  eval_data:
    parameters:
      # Which data to read.
      dataset_name:
        values: [ANONYMIZED/Skylion007-openwebtext-tokenizer-gpt2]
      split:
        values: [train]
      streaming:
        values: [true]
      # Tokenizer-related settings and n_ctx.
      is_tokenized:
        values: [true]
      tokenizer_name:
        values: [gpt2]
      n_ctx:
        values: [1024]
# Argument vector for each run. ${env}, ${interpreter} and ${program} are
# placeholders filled in by the sweep runner (W&B command macros); the final
# entry is passed through literally.
# NOTE(review): presumably the base config that the swept `parameters`
# override - confirm against the program's argument handling.
command:
  - ${env}
  - ${interpreter}
  - ${program}
  - e2e_sae/scripts/train_tlens_saes/gpt2_e2e_recon.yaml