import torch
from sae_lens import SAE
from transformer_lens import HookedTransformer


def load_sae_saelens(
    release: str = "gemma-scope-9b-it-res-canonical",
    sae_id: str = "layer_9/width_131k/canonical",
    device: str = "cuda",
    dtype: "str | torch.dtype" = "bfloat16",
):
    """Load a pretrained SAE via SAELens and cast it to ``dtype``.

    Args:
        release: SAELens release name passed to ``SAE.from_pretrained``.
        sae_id: Identifier of the SAE within ``release``.
        device: Device passed to ``SAE.from_pretrained``.
        dtype: Target dtype: either the name of a ``torch`` dtype attribute
            (e.g. ``"bfloat16"``, ``"float32"``) or a ``torch.dtype`` object.
            Non-string values are passed to ``.to()`` unchanged.

    Returns:
        The loaded SAE, cast to the requested dtype.

    Raises:
        ValueError: If ``dtype`` is a string that does not name a
            ``torch.dtype`` (checked before anything is downloaded).
    """
    # Resolve the dtype up front so a typo fails immediately rather than
    # after a potentially large download/load. Checking the type also rejects
    # names of non-dtype torch attributes (e.g. "nn"), which getattr alone
    # would happily return.
    if isinstance(dtype, str):
        torch_dtype = getattr(torch, dtype, None)
        if not isinstance(torch_dtype, torch.dtype):
            raise ValueError(f"{dtype!r} is not a valid torch dtype name")
    else:
        torch_dtype = dtype

    loaded = SAE.from_pretrained(release=release, sae_id=sae_id, device=device)
    # Older SAELens versions return (sae, cfg_dict, sparsity); newer ones
    # return only the SAE. Accept either shape.
    sae = loaded[0] if isinstance(loaded, tuple) else loaded

    return sae.to(torch_dtype)


def load_model_tlens(
    model_name: str = "google/gemma-2b-it",
    device: str = "cuda",
    dtype: str = "bfloat16",
    load_with_no_processing: bool = False,
):
    """Load a TransformerLens ``HookedTransformer`` by name.

    Args:
        model_name: Model identifier forwarded to TransformerLens.
        device: Device forwarded to TransformerLens.
        dtype: Dtype forwarded to TransformerLens unchanged.
        load_with_no_processing: When True, load with
            ``HookedTransformer.from_pretrained_no_processing``; otherwise
            use ``HookedTransformer.from_pretrained``.

    Returns:
        Whatever the selected TransformerLens constructor returns.

    NOTE(review): the default ``model_name`` ("google/gemma-2b-it") does not
    appear to match the default SAE release in ``load_sae_saelens`` (whose
    name indicates a 9b-it model); pass a matching model/SAE pair explicitly.
    """
    # Pick the constructor first, then make one call with the shared kwargs.
    loader = (
        HookedTransformer.from_pretrained_no_processing
        if load_with_no_processing
        else HookedTransformer.from_pretrained
    )
    return loader(model_name, dtype=dtype, device=device)
