from typing import Tuple

import torch
from sae.config import Config
from sae.hooked_vit import HookedVisionTransformer
from sae.sparse_autoencoder import SparseAutoencoder
from transformers import CLIPModel, CLIPProcessor


class ModelLoader:
    """A utility class for loading SAE and ViT models."""

    @staticmethod
    def load_sae(sae_path: str, device: str) -> Tuple[SparseAutoencoder, Config]:
        """
        Load a sparse autoencoder model from a checkpoint file.

        The checkpoint is first deserialized onto the CPU. An SAE is then
        built from its stored config, and the saved weights are copied in.
        Finally the model is switched to eval mode and moved to ``device``.

        Args:
            sae_path: Path to a checkpoint readable by ``torch.load``. The
                loaded object is expected to support key lookup and
                membership tests. It must contain a config entry under
                ``"cfg"`` or ``"config"`` and weights under ``"state_dict"``.
            device: Device passed to the ``SparseAutoencoder`` constructor
                and used as the target of the final ``.to(device)``.

        Returns:
            A ``(sae, cfg)`` tuple. ``sae`` is the loaded model in eval mode,
            and ``cfg`` is the ``Config`` built from the checkpoint.

        Raises:
            KeyError: If the checkpoint has neither a ``"cfg"`` nor a
                ``"config"`` entry, or has no ``"state_dict"`` entry.
        """
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources.
        # Loading onto the CPU first avoids depending on the device the
        # checkpoint was saved from.
        checkpoint = torch.load(sae_path, map_location="cpu")

        # Older/newer checkpoints store the config under different keys.
        # Pick the key explicitly instead of relying on a blanket except:
        # otherwise an error inside Config(...) would be masked as a
        # misleading KeyError for "config".
        if "cfg" in checkpoint:
            raw_cfg = checkpoint["cfg"]
        elif "config" in checkpoint:
            raw_cfg = checkpoint["config"]
        else:
            raise KeyError(
                f"Checkpoint {sae_path!r} has no 'cfg' or 'config' entry"
            )
        cfg = Config(raw_cfg)

        sae = SparseAutoencoder(cfg, device)
        sae.load_state_dict(checkpoint["state_dict"])
        sae.eval().to(device)

        return sae, cfg

    @staticmethod
    def get_sae_and_vit(
        sae_path: str,
        device: str,
        backbone: str,
    ) -> Tuple[SparseAutoencoder, HookedVisionTransformer, Config]:
        """
        Load both SAE and ViT models.

        Args:
            sae_path: Path to the SAE checkpoint; forwarded to
                ``ModelLoader.load_sae``.
            device: Device used for both the SAE and the hooked ViT wrapper.
            backbone: Identifier passed to ``CLIPModel.from_pretrained`` and
                ``CLIPProcessor.from_pretrained``. This can be a Hugging
                Face model id or a local directory.

        Returns:
            A ``(sae, vit, cfg)`` tuple:
            - ``sae`` is the loaded sparse autoencoder.
            - ``vit`` wraps the CLIP model and processor in a
              ``HookedVisionTransformer``.
            - ``cfg`` is the config read from the SAE checkpoint.
        """
        sae, cfg = ModelLoader.load_sae(sae_path, device)
        model = CLIPModel.from_pretrained(backbone)
        processor = CLIPProcessor.from_pretrained(backbone)
        # Device placement of the CLIP model is delegated to the wrapper.
        # NOTE(review): confirm HookedVisionTransformer moves the model to
        # `device` and sets eval mode.
        vit = HookedVisionTransformer(model, processor, device=device)

        return sae, vit, cfg
