# Transcoder training sample code

"""
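This sample script can be used to train a transcoder on a model of your choice.
As configured below (is_transcoder = False), it trains a standard sparse
    autoencoder (SAE) on pythia-410m, one run per L1 coefficient passed on
    the command line.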
This code, along with the transcoder training code more generally, was largely
    adapted from an older version of Joseph Bloom's SAE training repo, the latest
    version of which can be found at https://github.com/jbloomAus/SAELens.
Most of the parameters given here are the same as the SAE training parameters
    listed at https://jbloomaus.github.io/SAELens/training_saes/.
Transcoder-specific parameters are marked as such in comments.

"""

import torch
import os 
import sys
import numpy as np

os.environ["TOKENIZERS_PARALLELISM"] = "false"

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.utils import LMSparseAutoencoderSessionloader
from sae_training.train_sae_on_language_model import train_sae_on_language_model

lr = 0.0002  # learning rate

# L1 sparsity regularization coefficients, one training run per value, read
# from the command line, e.g. `python <this script> 0.000008 0.0000085`.
l1_coeffs = [float(x) for x in sys.argv[1:]]
if not l1_coeffs:
    # Without at least one coefficient the loop below would run zero times and
    # the script would exit "successfully" having trained nothing.
    sys.exit(
        f"usage: {sys.argv[0]} L1_COEFF [L1_COEFF ...]  (e.g. 0.000008 0.0000085)"
    )
print(l1_coeffs)
layer = 15  # index of the transformer block whose MLP input is hooked

for l1_coeff in l1_coeffs:
    cfg = LanguageModelSAERunnerConfig(
        # Data Generating Function (Model + Training Distribution)

        # "hook_point" is the TransformerLens HookPoint representing
        #    the input activations to the SAE/transcoder we want to train.
        # Here, "ln2.hook_normalized" refers to the activations after the
        #    pre-MLP LayerNorm -- that is, the inputs to the MLP.
        # You might alternatively prefer to train on
        #    f"blocks.{layer}.hook_resid_mid", which corresponds to the input
        #    to the pre-MLP LayerNorm.
        hook_point=f"blocks.{layer}.ln2.hook_normalized",
        hook_point_layer=layer,
        d_in=1024,
        dataset_path="Skylion007/openwebtext",
        is_dataset_tokenized=False,
        model_name='pythia-410m',

        # Transcoder-specific parameters.
        is_transcoder=False,  # We're not training a transcoder here.

        # SAE Parameters
        expansion_factor=32,
        b_dec_init_method="mean",

        # Training Parameters
        lr=lr,
        l1_coefficient=l1_coeff,
        lr_scheduler_name="constantwithwarmup",
        train_batch_size=4096,
        context_size=128,
        lr_warm_up_steps=6050,

        # Activation Store Parameters
        n_batches_in_buffer=128,
        total_training_tokens=1_000_000 * 60,
        store_batch_size=32,

        # Dead Neurons and Sparsity
        use_ghost_grads=False,
        feature_sampling_method=None,  # alternative: 'anthropic'
        feature_sampling_window=1000,
        resample_batches=1028,
        dead_feature_window=5000,
        dead_feature_threshold=1e-6,  # alternative: 1e-8

        # WANDB
        log_to_wandb=False,

        # Misc
        use_tqdm=True,
        device="cuda",
        seed=42,
        n_checkpoints=1,
        # `{l1_coeff}` renders the shortest exact repr of the float, so
        # distinct coefficients always get distinct directories. The former
        # `:.2` format kept only two significant digits and merged e.g.
        # 8.51e-06 and 8.54e-06 into "l1_8.5e-06". Values with at most two
        # significant digits (8e-06, 8.5e-06) render identically either way.
        checkpoint_path=f"pythia-saes/l1_{l1_coeff}",  # change as you please
        dtype=torch.float32,
    )

    print(f"About to start training with lr {lr} and l1 {l1_coeff}")
    print(f"Checkpoint path: {cfg.checkpoint_path}")
    print(cfg)

    loader = LMSparseAutoencoderSessionloader(cfg)
    model, sparse_autoencoder, activations_loader = loader.load_session()

    # train SAE
    sparse_autoencoder = train_sae_on_language_model(
        model, sparse_autoencoder, activations_loader,
        n_checkpoints=cfg.n_checkpoints,
        batch_size=cfg.train_batch_size,
        feature_sampling_method=cfg.feature_sampling_method,
        feature_sampling_window=cfg.feature_sampling_window,
        feature_reinit_scale=cfg.feature_reinit_scale,
        dead_feature_threshold=cfg.dead_feature_threshold,
        dead_feature_window=cfg.dead_feature_window,
        use_wandb=cfg.log_to_wandb,
        wandb_log_frequency=cfg.wandb_log_frequency,
        dead_features_stopping_percentage=0.6,  # stop if over 60% of features are dead
    )

    # save sae to checkpoints folder
    # NOTE(review): whether save_model creates missing directories is not
    # visible here; creating it up front is a no-op if it already exists.
    os.makedirs(cfg.checkpoint_path, exist_ok=True)
    path = f"{cfg.checkpoint_path}/final_{sparse_autoencoder.get_name()}.pt"
    sparse_autoencoder.save_model(path)

    # Release this run's model, SAE and activation store before the next
    # iteration's load_session() allocates fresh ones. Otherwise both sets
    # stay referenced at once and peak GPU memory roughly doubles.
    del model, sparse_autoencoder, activations_loader, loader
    torch.cuda.empty_cache()