#%%
import torch
import wandb
import transformer_lens.utils as utils
from training import train_sae
from sae import *
from activation_store import ActivationsStore
from sae import TrainingConfig, SAEConfig
from sae.config import post_init_cfg
from transformer_lens import HookedTransformer
import pyrallis
import ast
import os

# Global PyTorch matmul precision settings, applied once at import time.
# Both lines opt float32 matrix multiplications into reduced-precision
# (TF32-class) kernels; see the PyTorch numerical-accuracy notes for the
# exact trade-off.
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision('high')

def _load_model(cfg: TrainingConfig):
    """Load the HookedTransformer named by ``cfg.model_name`` onto ``cfg.device``.

    The model is loaded in bfloat16 via ``from_pretrained_no_processing``.
    Any ``Exception`` raised while loading is printed and the load is
    attempted again after a 10-second pause; there is no retry limit.

    NOTE(review): a permanent failure (e.g. a wrong model name) makes this
    loop forever -- presumably the retry targets transient download errors;
    confirm whether a retry cap is wanted.
    """
    import time

    while True:
        try:
            return HookedTransformer.from_pretrained_no_processing(
                cfg.model_name, dtype=torch.bfloat16
            ).to(cfg.device)
        except Exception as err:
            print(f"Error loading model: {err}")
            time.sleep(10)


def _build_sae(sae_type, sae_cfg):
    """Instantiate the SAE class selected by ``sae_type`` from ``sae_cfg``.

    Recognised values are "vanilla", "topk", "batchtopk", "jumprelu" and
    "kronsae". For "kronsae", ``sae_cfg.cartesian_op`` is set to "mul"
    before the model is constructed.

    Raises:
        ValueError: if ``sae_type`` is not one of the recognised values.
    """
    if sae_type == "vanilla":
        return VanillaSAE(sae_cfg)
    if sae_type == "topk":
        return TopKSAE(sae_cfg)
    if sae_type == "batchtopk":
        return BatchTopKSAE(sae_cfg)
    if sae_type == "jumprelu":
        return JumpReLUSAE(sae_cfg)
    if sae_type == "kronsae":
        sae_cfg.cartesian_op = "mul"
        return KronSAE(sae_cfg)
    raise ValueError(f"SAE type {sae_type} is not understood.")


@pyrallis.wrap()
def main(cfg: TrainingConfig):
    """Set up and launch one SAE training run described by ``cfg``.

    Steps, in order:
      1. Copy every entry of ``cfg.sweep_pair`` (if any) onto ``cfg``.
      2. Start a wandb run when ``cfg.enable_wandb`` is set.
      3. Load the language model and record its ``d_model`` as ``cfg.act_size``.
      4. If ``cfg.hook_point`` contains a comma, parse it as a Python literal
         (a list result is converted to a tuple).
      5. Build the activation store, finalise the config via
         ``post_init_cfg``, construct the SAE and call ``train_sae``.

    ``train_sae`` receives ``train_transcoder=True`` exactly when
    ``cfg.hook_point`` is a tuple.

    Raises:
        ValueError: if ``cfg.sae_type`` names an unknown SAE variant.
    """
    # Sweep overrides win over whatever the config file / CLI provided.
    for field_name, override in (cfg.sweep_pair or {}).items():
        setattr(cfg, field_name, override)

    if cfg.enable_wandb:
        wandb.init(
            project=cfg.wandb_project,
            config=vars(cfg),
            allow_val_change=True,
            name=f"{cfg.dict_size}_{cfg.sae_type}_{cfg.wandb_run_suffix}",
        )

    model = _load_model(cfg)
    cfg.act_size = model.cfg.d_model
    print(f"act_size: {cfg.act_size}")

    # A comma means hook_point was given as a literal holding several hook
    # names rather than as a single hook name.
    if "," in cfg.hook_point:
        parsed = ast.literal_eval(cfg.hook_point)
        cfg.hook_point = tuple(parsed) if isinstance(parsed, list) else parsed

    activations_store = ActivationsStore(model, cfg)
    cfg = post_init_cfg(cfg, activations_store)

    sae = _build_sae(cfg.sae_type, SAEConfig.from_training_config(cfg))
    uses_hook_tuple = isinstance(cfg.hook_point, tuple)
    train_sae(sae, activations_store, model, cfg, train_transcoder=uses_hook_tuple)


# Script entry point. main() is called without arguments because it is
# wrapped by pyrallis.wrap(), which is expected to supply `cfg`
# (see the pyrallis documentation for how the config is parsed).
if __name__ == "__main__":
    main()