import transformer_lens.utils as utils
import torch 

def get_default_cfg():
    default_cfg = {
        "seed": 49,
        "batch_size": 4096,
        "lr": 3e-4,
        "num_tokens": int(1e9),
        "l1_coeff": 0,
        "beta1": 0.9,
        "beta2": 0.99,
        "max_grad_norm": 100000,
        "seq_len": 128,
        "dtype": torch.float32,
        "model_dtype": torch.float32,
        "model_name": "gpt2-small",
        "site": "resid_pre",
        "layer": 3,
        "act_size": 2304,
        "dict_size": 12288,
        "device": "cuda:0",
        "model_batch_size": 512,
        "num_batches_in_buffer": 10,
        "dataset_path": "Skylion007/openwebtext",
        "wandb_project": "sparse_autoencoders_1",
        "input_unit_norm": True,
        "perf_log_freq": 200,
        'perf_log_freq_base_metrics': 20, 
        'perf_log_freq_cosine_sim': 400,
        "sae_type": "topk",
        "checkpoint_freq": 20000,
        "n_batches_to_dead": 200,

        # (Batch)TopKSAE specific
        "top_k": 32,
        "top_k_aux": 512,
        "aux_penalty": (1/32),
        
        # for jumprelu
        "bandwidth": 0.001,
        
        'ort_coeff': 0.0, 
    }
    default_cfg = post_init_cfg(default_cfg)
    return default_cfg

def post_init_cfg(cfg):
    cfg["hook_point"] = utils.get_act_name(cfg["site"], cfg["layer"])
    cfg["name"] = f"{cfg['model_name']}_{cfg['hook_point']}_{cfg['dict_size']}_{cfg['sae_type']}_{cfg['top_k']}_{cfg['l1_coeff']}_{cfg['ort_coeff']}"
    return cfg