# train_config
# Data source: a Hugging Face dataset path, streamed instead of fully downloaded.
dataset_path: "apollo-research/roneneldan-TinyStories-tokenizer-gpt2"
dataset_trust_remote_code: true
streaming: true
# NOT IN USE: whether the dataset is tokenized is now detected automatically.
is_dataset_tokenized: true
# Activations are generated on the fly (no cache), so no cache path is set.
use_cached_activations: false
cached_activations_path: null
# Number of batches held in the activation buffer; a larger buffer gives
# better-shuffled activations.
n_batches_in_buffer: 64
# Written without digit separators on purpose: `2_000_000_000` is resolved as an
# integer only by YAML 1.1 resolvers (e.g. PyYAML). The YAML 1.2 core schema does
# not treat underscores as part of an integer, so strict 1.2 parsers may load a string.
training_tokens: 2000000000
# 0 finetuning tokens are configured for this run.
finetuning_tokens: 0
# Prompts per language-model batch when generating activations.
store_batch_size_prompts: 16
# Tokens per batch in the SAE training loop.
train_batch_size_tokens: 8192
device: "cuda"
# NOTE(review): the key reference in this file advises CPU for the activation
# store to save VRAM; "cuda" is kept here as originally configured.
act_store_device: "cuda"
seed: 42
dtype: "float16"
# Adam optimizer betas.
adam_beta1: 0.9
adam_beta2: 0.999
# Sparsity penalty: L1 coefficient, Lp norm, and warm-up steps for the L1 loss.
l1_coefficient: 5.0
lp_norm: 1
l1_warm_up_steps: 0
lr: 0.0001
# Checkpointing.
# NOTE(review): the path says "gpt2-base" while model_name is "tiny-stories-1M";
# confirm the directory name is intentional.
n_checkpoints: 10
checkpoint_path: "outputs/gpt2-base/checkpoints"
verbose: true
resume: false

# wandb
# Where the run is recorded in Weights & Biases: project, entity, and run name.
wandb_project: "sae_training"
wandb_entity: null
run_name: "gpt2-base"
# Logging switch and cadence. Evaluation frequency is expressed in wandb logs
# (presumably every 10th log here; confirm against the trainer).
log_to_wandb: true
wandb_log_frequency: 100
eval_every_n_wandb_logs: 10

# model_config
# Model to load and the class used to wrap it.
# Reference table of model properties:
# https://transformerlensorg.github.io/TransformerLens/generated/model_properties_table.html
model_class_name: "HookedTransformer"
model_name: "tiny-stories-1M"

# Hook point the SAE is trained on. hook_layer is the layer index to hook and is
# used to stop forward passes early; hook_eval is not currently in use.
hook_name: "blocks.1.hook_mlp_out"
hook_layer: 1
hook_head_index: null
hook_eval: "NOT_IN_USE"

# SAE dimensions. d_sae is null, so it defaults to d_in * expansion_factor
# (64 * 4 = 256).
d_in: 64
expansion_factor: 4
d_sae: null

# SAE forward-pass options.
activation_fn: "relu"
normalize_sae_decoder: false
noise_scale: 0.0
apply_b_dec_to_input: false

# Initialisation: start from scratch (no pretrained SAE), zero decoder bias,
# and none of the optional decoder/encoder init schemes enabled.
from_pretrained_path: null
b_dec_init_method: "zeros"
decoder_orthogonal_init: false
decoder_heuristic_init: false
init_encoder_as_decoder_transpose: false



# model_name (str): The name of the model to use. This should be the name of the model in the Hugging Face model hub.
# model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
# hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook.
# hook_eval (str): NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
# hook_layer (int): The index of the layer to hook. Used to stop forward passes early and speed up processing.
# hook_head_index (int, optional): When the hook is for an activation with a head index, we can specify a specific head to use here.
# dataset_path (str): A Hugging Face dataset path.
# dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface.
# streaming (bool): Whether to stream the dataset. Streaming large datasets is usually practical.
# is_dataset_tokenized (bool): NOT IN USE. We used to use this but now automatically detect if the dataset is tokenized.
# context_size (int): The context size to use when generating activations on which to train the SAE.
# use_cached_activations (bool): Whether to use cached activations. This is useful when doing sweeps over the same activations.
# cached_activations_path (str, optional): The path to the cached activations.
# d_in (int): The input dimension of the SAE.
# d_sae (int, optional): The output dimension of the SAE. If None, defaults to `d_in * expansion_factor`.
# b_dec_init_method (str): The method to use to initialize the decoder bias. Zeros is likely fine.
# expansion_factor (int): The expansion factor. Larger is better but more computationally expensive.
# activation_fn (str): The activation function to use. Relu is standard.
# normalize_sae_decoder (bool): Whether to normalize the SAE decoder. Unit normed decoder weights used to be preferred.
# noise_scale (float): Using noise to induce sparsity is supported but not recommended.
# from_pretrained_path (str, optional): The path to a pretrained SAE. We can finetune an existing SAE if needed.
# apply_b_dec_to_input (bool): Whether to apply the decoder bias to the input. Not currently advised.
# decoder_orthogonal_init (bool): Whether to use orthogonal initialization for the decoder. Not currently advised.
# decoder_heuristic_init (bool): Whether to use heuristic initialization for the decoder. See Anthropic April Update.
# init_encoder_as_decoder_transpose (bool): Whether to initialize the encoder as the transpose of the decoder. See Anthropic April Update.
# n_batches_in_buffer (int): The number of batches in the buffer. When not using cached activations, a buffer in RAM is used. The larger it is, the better shuffled the activations will be.
# training_tokens (int): The number of training tokens.
# finetuning_tokens (int): The number of finetuning tokens. See [here](https://www.lesswrong.com/posts/3JuSjTZyMzaSeTxKk/addressing-feature-suppression-in-saes)
# store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating activations.
# train_batch_size_tokens (int): The batch size for training. This controls the batch size of the SAE Training loop.
# normalize_activations (str): Activation Normalization Strategy. Either none, expected_average_only_in (estimate the average activation norm and divide activations by it -> this can be folded post training and set to None), or constant_norm_rescale (at runtime set activation norm to sqrt(d_in) and then scale up the SAE output).
# device (str): The device to use. Usually cuda.
# act_store_device (str): The device to use for the activation store. CPU is advised in order to save vram.
# seed (int): The seed to use.
# dtype (str): The data type to use.
# prepend_bos (bool): Whether to prepend the beginning of sequence token. You should use whatever the model was trained with.
# autocast (bool): Whether to use autocast during training. Saves vram.
# autocast_lm (bool): Whether to use autocast during activation fetching.
# compile_llm (bool): Whether to compile the LLM.
# llm_compilation_mode (str): The compilation mode to use for the LLM.
# compile_sae (bool): Whether to compile the SAE.
# sae_compilation_mode (str): The compilation mode to use for the SAE.
# train_batch_size_tokens (int): The batch size for training.
# adam_beta1 (float): The beta1 parameter for Adam.
# adam_beta2 (float): The beta2 parameter for Adam.
# mse_loss_normalization (str): The normalization to use for the MSE loss.
# l1_coefficient (float): The L1 coefficient.
# lp_norm (float): The Lp norm.
# scale_sparsity_penalty_by_decoder_norm (bool): Whether to scale the sparsity penalty by the decoder norm.
# l1_warm_up_steps (int): The number of warm-up steps for the L1 loss.
# lr (float): The learning rate.
# lr_scheduler_name (str): The name of the learning rate scheduler to use.
# lr_warm_up_steps (int): The number of warm-up steps for the learning rate.
# lr_end (float): The end learning rate for the cosine annealing scheduler.
# lr_decay_steps (int): The number of decay steps for the learning rate.
# n_restart_cycles (int): The number of restart cycles for the cosine annealing warm restarts scheduler.
# finetuning_method (str): The method to use for finetuning.
# use_ghost_grads (bool): Whether to use ghost gradients.
# feature_sampling_window (int): The feature sampling window.
# dead_feature_window (int): The dead feature window.
# dead_feature_threshold (float): The dead feature threshold.
# n_eval_batches (int): The number of evaluation batches.
# eval_batch_size_prompts (int): The batch size for evaluation.
# log_to_wandb (bool): Whether to log to Weights & Biases.
# log_activations_store_to_wandb (bool): NOT CURRENTLY USED. Whether to log the activations store to Weights & Biases.
# log_optimizer_state_to_wandb (bool): NOT CURRENTLY USED. Whether to log the optimizer state to Weights & Biases.
# wandb_project (str): The Weights & Biases project to log to.
# wandb_id (str): The Weights & Biases ID.
# run_name (str): The name of the run.
# wandb_entity (str): The Weights & Biases entity.
# wandb_log_frequency (int): The frequency to log to Weights & Biases.
# eval_every_n_wandb_logs (int): The frequency to evaluate.
# resume (bool): Whether to resume training.
# n_checkpoints (int): The number of checkpoints.
# checkpoint_path (str): The path to save checkpoints.
# verbose (bool): Whether to print verbose output.
# model_kwargs (dict[str, Any]): Additional keyword arguments for the model.
# model_from_pretrained_kwargs (dict[str, Any]): Additional keyword arguments for the model from pretrained.