import os
from modules import constants
os.environ["HF_HOME"] = os.path.join(constants.DATA_DIR, "raw")

import torch
import argparse
import random
from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

# Command-line interface for the SAE training script.
# Every option has a default, so the script can be launched with no arguments.
parser = argparse.ArgumentParser(description="Run SAE training.")
parser.add_argument("--run_name", type=str, default="baseline_sae", help="Name of the run.")
parser.add_argument("--model_name", type=str, default="pythia-160m-deduped", help="Name of the model to use.")
parser.add_argument("--dataset_name", type=str, default="EleutherAI/fineweb-edu-dedup-10b", help="Name of the dataset to use.")
parser.add_argument("--hook_name", type=str, default="blocks.6.hook_resid_pre", help="Name of the hook to use.")
parser.add_argument("--hook_layer", type=int, default=6, help="Layer number of the hook.")
parser.add_argument("--d_in", type=int, default=768, help="Input dimension.")
# The help text used to be a copy of the --d_in one; the latent size reported
# by the script is d_in * expansion_factor.
parser.add_argument(
    "--expansion_factor",
    type=int,
    default=32,
    help="Expansion factor (SAE latent size = d_in * expansion_factor).",
)
parser.add_argument("--total_training_steps", type=int, default=100000, help="Total training steps.")
parser.add_argument("--batch_size", type=int, default=2048, help="Batch size.")

# Five-character suffix drawn from lowercase letters and digits; it is appended
# to the run name further down so repeated launches get distinct names.
random_tag = "".join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=5))

args = parser.parse_args()

# Quantities derived from the CLI arguments.
total_training_tokens = args.batch_size * args.total_training_steps
lr_warm_up_steps = 0
lr_decay_steps = args.total_training_steps // 5  # one fifth (20%) of the steps
l1_warm_up_steps = args.total_training_steps // 20  # one twentieth (5%) of the steps

# Echo the effective settings so they show up in the job log.
print(f"Running with model: {args.model_name}")
print(f"Running with dataset: {args.dataset_name}")
print(f"Running with hook: {args.hook_name}")
print(f"Latent size: {args.d_in * args.expansion_factor}")
print(f"Total training tokens: {total_training_tokens}")

# Stem shared by the W&B run name and the checkpoint directory name.
_run_stem = f"{args.run_name}_{args.model_name}_{args.hook_name}"

# Checkpoints go under MODEL_DIR/baseline_saes/<stem>_ef<expansion factor>.
_checkpoint_dir = os.path.join(
    constants.MODEL_DIR,
    "baseline_saes",
    f"{_run_stem}_ef{args.expansion_factor}",
)

# Full training configuration handed to the SAE training runner.
cfg = LanguageModelSAERunnerConfig(
    # --- Model, hook point and data ---
    model_name=args.model_name,
    # model_class_name="AutoModelForCausalLM",  # leave disabled for TransformerLens models
    hook_name=args.hook_name,
    hook_layer=args.hook_layer,
    d_in=args.d_in,
    dataset_path=args.dataset_name,
    is_dataset_tokenized=False,
    streaming=False,
    # --- SAE architecture and initialisation ---
    mse_loss_normalization=None,
    expansion_factor=args.expansion_factor,
    b_dec_init_method="zeros",
    apply_b_dec_to_input=False,
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    normalize_activations="expected_average_only_in",
    # --- Optimiser, schedules and sparsity penalty ---
    lr=5e-5,
    adam_beta1=0.9,
    adam_beta2=0.999,
    lr_scheduler_name="constant",
    lr_warm_up_steps=lr_warm_up_steps,
    lr_decay_steps=lr_decay_steps,
    l1_coefficient=5,
    l1_warm_up_steps=l1_warm_up_steps,
    lp_norm=1.0,
    train_batch_size_tokens=args.batch_size,
    context_size=2048,
    # --- Activation store ---
    n_batches_in_buffer=32,
    training_tokens=total_training_tokens,
    store_batch_size_prompts=16,
    # --- Dead-feature tracking / resampling ---
    use_ghost_grads=False,
    feature_sampling_window=1000,
    dead_feature_window=1000,
    dead_feature_threshold=1e-4,
    # --- Weights & Biases logging ---
    log_to_wandb=True,
    wandb_project="Noise_SAE",
    run_name=f"{_run_stem}_{random_tag}",
    wandb_log_frequency=30,
    eval_every_n_wandb_logs=20,
    # --- Runtime ---
    device="cuda",
    seed=42,
    n_checkpoints=0,
    checkpoint_path=_checkpoint_dir,
    dtype="float32",
    # autocast=True,
    # autocast_lm=True,
    compile_sae=True,
    compile_llm=True,
)

# Launch training; whatever the runner returns is kept as `sparse_autoencoder`.
sparse_autoencoder = SAETrainingRunner(cfg).run()

# Once training has returned, write the config as JSON inside
# cfg.checkpoint_path, creating that directory first if it is missing.
os.makedirs(cfg.checkpoint_path, exist_ok=True)
cfg_file_path = os.path.join(cfg.checkpoint_path, "sae_training_config.json")
cfg.save_to_json(cfg_file_path)
