"""
This file is just a copy of the original config.py from the Transcoder paper's repo.
https://github.com/jacobdunefsky/transcoder_circuits/tree/master
"""

import gzip
import os
import pickle
from functools import partial

import einops
import torch
import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from tqdm import tqdm
from transformer_lens.hook_points import HookedRootModule, HookPoint

# from .geom_median.src.geom_median.torch import compute_geometric_median


class SparseAutoencoder(HookedRootModule):
    """Sparse autoencoder / transcoder with a ReLU encoder and an affine decoder.

    The forward pass computes::

        feature_acts = relu((x - b_dec) @ W_enc + b_enc)
        sae_out = feature_acts @ W_dec + (b_dec_out if is_transcoder else b_dec)

    For a plain SAE ``d_out == d_in`` and the output is compared against the
    input. For a transcoder (``cfg.is_transcoder``) the output may have a
    different width (``cfg.d_out``) and is compared against ``mse_target``.

    Parameters are registered in the order W_enc, b_enc, W_dec, b_dec (then
    b_dec_out for transcoders). The neuron-resampling methods rely on that
    order when resetting Adam state, and saved checkpoints key on these names.
    """

    def __init__(
        self,
        cfg,
    ):
        super().__init__()
        self.cfg = cfg
        self.d_in = cfg.d_in
        if not isinstance(self.d_in, int):
            raise ValueError(
                f"d_in must be an int but was {self.d_in=}; {type(self.d_in)=}"
            )
        self.d_sae = cfg.d_sae
        self.l1_coefficient = cfg.l1_coefficient
        self.dtype = cfg.dtype
        self.device = cfg.device

        # Transcoders may map d_in -> a different d_out; plain SAEs keep d_out == d_in.
        self.d_out = self.d_in
        if cfg.is_transcoder and cfg.d_out is not None:
            self.d_out = cfg.d_out

        # Sparse-connection transcoders are penalised against the directions of a
        # fixed, pre-trained SAE (see get_sparse_connection_loss).
        self.spacon_sae_W_dec = None
        if cfg.is_sparse_connection:
            self.spacon_sae_W_dec = self._load_sparse_connection_directions(cfg)

        # NOTE: if using resampling neurons method, you must ensure that we initialise the weights in the order W_enc, b_enc, W_dec, b_dec
        self.W_enc = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(self.d_in, self.d_sae, dtype=self.dtype, device=self.device)
            )
        )
        self.b_enc = nn.Parameter(
            torch.zeros(self.d_sae, dtype=self.dtype, device=self.device)
        )

        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(
                    self.d_sae, self.d_out, dtype=self.dtype, device=self.device
                )
            )
        )

        with torch.no_grad():
            # Anthropic normalize this to have unit columns
            self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)

        # b_dec is subtracted from the input before encoding, so it has size d_in.
        self.b_dec = nn.Parameter(
            torch.zeros(self.d_in, dtype=self.dtype, device=self.device)
        )

        # Transcoders get a separate output bias of size d_out.
        self.b_dec_out = None
        if cfg.is_transcoder:
            self.b_dec_out = nn.Parameter(
                torch.zeros(self.d_out, dtype=self.dtype, device=self.device)
            )

        self.hook_sae_in = HookPoint()
        self.hook_hidden_pre = HookPoint()
        self.hook_hidden_post = HookPoint()
        self.hook_sae_out = HookPoint()

        self.setup()  # Required for `HookedRootModule`s

    def _load_sparse_connection_directions(self, cfg):
        """Load the reference SAE directions used by the sparse-connection loss.

        Reads the checkpoint at ``cfg.sparse_connection_sae_path`` (``.pt``,
        ``.pkl`` or ``.pkl.gz``, containing a ``"state_dict"`` entry) and
        returns its ``W_dec``, or ``W_enc.T`` when
        ``cfg.sparse_connection_use_W_enc`` is set, moved to ``self.device``.

        NOTE: ``torch.load(weights_only=False)`` and ``pickle.load`` can run
        arbitrary code from the file; only use trusted checkpoints.
        """
        path = cfg.sparse_connection_sae_path

        if path.endswith(".pt"):
            # The checkpoint pickles the cfg object, which weights_only=True refuses
            # to load; map to CPU first so CUDA checkpoints load anywhere.
            checkpoint = torch.load(path, map_location="cpu", weights_only=False)
        elif path.endswith(".pkl.gz"):
            with gzip.open(path, "rb") as f:
                checkpoint = pickle.load(f)
        elif path.endswith(".pkl"):
            with open(path, "rb") as f:
                checkpoint = pickle.load(f)
        else:
            raise ValueError(
                f"Unexpected file extension: {path}, supported extensions are .pt, .pkl, and .pkl.gz"
            )

        weights = checkpoint["state_dict"]
        if cfg.sparse_connection_use_W_enc:
            directions = weights["W_enc"].to(self.device).T
        else:
            directions = weights["W_dec"].to(self.device)
        del checkpoint, weights
        torch.cuda.empty_cache()
        return directions

    def forward(self, x, dead_neuron_mask=None, mse_target=None):
        """Encode/decode ``x`` and compute the training losses.

        Args:
            x: input activations whose last dimension is ``d_in``. The ghost-grads
                path indexes ``hidden_pre[:, mask]`` and therefore needs 2-D input.
            dead_neuron_mask: boolean mask over the ``d_sae`` features marking dead
                features. Only used for ghost grads; when ``None`` the ghost-grad
                term is skipped (e.g. for the resampling/eval callers below).
            mse_target: reconstruction target (for transcoders, the output being
                predicted). Defaults to ``x``.

        Returns:
            ``(sae_out, feature_acts, loss, mse_loss, l1_loss, mse_loss_ghost_resid)``
            where the last four entries are reduced to scalars.
        """
        x = x.to(self.dtype)
        # Remove encoder bias as per Anthropic
        sae_in = self.hook_sae_in(x - self.b_dec)

        hidden_pre = self.hook_hidden_pre(
            einops.einsum(sae_in, self.W_enc, "... d_in, d_in d_sae -> ... d_sae")
            + self.b_enc
        )
        feature_acts = self.hook_hidden_post(torch.nn.functional.relu(hidden_pre))

        # Transcoders have their own output bias; plain SAEs reuse b_dec.
        out_bias = self.b_dec_out if self.cfg.is_transcoder else self.b_dec
        sae_out = self.hook_sae_out(
            einops.einsum(
                feature_acts, self.W_dec, "... d_sae, d_sae d_out -> ... d_out"
            )
            + out_bias
        )

        target = x if mse_target is None else mse_target
        # Element-wise squared error, normalised by each target vector's L2 norm.
        mse_loss = (
            torch.pow((sae_out - target.float()), 2)
            / (target**2).sum(dim=-1, keepdim=True).sqrt()
        )

        mse_loss_ghost_resid = torch.tensor(0.0, dtype=self.dtype, device=self.device)
        # Gate on config and training so evals are not slowed down.
        if (
            self.cfg.use_ghost_grads
            and self.training
            and dead_neuron_mask is not None
            and dead_neuron_mask.sum() > 0
        ):
            mse_loss_ghost_resid = self._ghost_grad_loss(
                hidden_pre, sae_out, target, dead_neuron_mask, mse_loss
            )

        mse_loss_ghost_resid = mse_loss_ghost_resid.mean()
        mse_loss = mse_loss.mean()
        # L1 over features for each token, averaged over all tokens.
        sparsity = torch.abs(feature_acts).sum(dim=-1).mean()
        l1_loss = self.l1_coefficient * sparsity
        loss = mse_loss + l1_loss + mse_loss_ghost_resid

        return sae_out, feature_acts, loss, mse_loss, l1_loss, mse_loss_ghost_resid

    def _ghost_grad_loss(self, hidden_pre, sae_out, target, dead_neuron_mask, mse_loss):
        """Unreduced ghost-grads loss for the dead features, rescaled to ``mse_loss``.

        Uses the reconstruction target (not the input) for the residual so that
        transcoders, whose output width may differ from the input's, work too.
        """
        # 1. What the live features failed to reconstruct.
        residual = target - sae_out
        l2_norm_residual = torch.norm(residual, dim=-1)

        # 2. Decode exp(pre-activation) of the dead features only and scale the
        #    result to half the residual's norm (the scale itself is detached).
        feature_acts_dead_neurons_only = torch.exp(hidden_pre[:, dead_neuron_mask])
        ghost_out = feature_acts_dead_neurons_only @ self.W_dec[dead_neuron_mask, :]
        l2_norm_ghost_out = torch.norm(ghost_out, dim=-1)
        norm_scaling_factor = l2_norm_residual / (1e-6 + l2_norm_ghost_out * 2)
        ghost_out = ghost_out * norm_scaling_factor[:, None].detach()

        # 3. Normalised MSE against the detached residual, rescaled to mse_loss's scale.
        residual = residual.detach()
        ghost_loss = (
            torch.pow((ghost_out - residual.float()), 2)
            / (residual**2).sum(dim=-1, keepdim=True).sqrt()
        )
        mse_rescaling_factor = (mse_loss / (ghost_loss + 1e-6)).detach()
        return mse_rescaling_factor * ghost_loss

    def get_sparse_connection_loss(self):
        """Sparse-connection penalty between the reference SAE and this model.

        ``dots[i, j]`` is the dot product of reference-SAE direction ``i`` with
        decoder row ``j`` of this model. The loss is the mean over SAE features
        of the L1 norm of each row, scaled by ``cfg.sparse_connection_l1_coeff``.

        Raises:
            ValueError: if the model was not built with ``cfg.is_sparse_connection``.
        """
        if self.spacon_sae_W_dec is None:
            raise ValueError(
                "get_sparse_connection_loss requires cfg.is_sparse_connection=True"
            )
        dots = self.spacon_sae_W_dec @ self.W_dec.T
        loss = dots.abs().sum(dim=1).mean()
        return self.cfg.sparse_connection_l1_coeff * loss

    @torch.no_grad()
    def initialize_b_dec(self, activation_store):
        """Initialise the decoder bias (or biases) per ``cfg.b_dec_init_method``."""
        if self.cfg.b_dec_init_method == "geometric_median":
            self.initialize_b_dec_with_geometric_median(activation_store)
        elif self.cfg.b_dec_init_method == "mean":
            self.initialize_b_dec_with_mean(activation_store)
        elif self.cfg.b_dec_init_method == "zeros":
            pass
        else:
            raise ValueError(
                f"Unexpected b_dec_init_method: {self.cfg.b_dec_init_method}"
            )

    @torch.no_grad()
    def _reinitialize_bias(self, bias, activations, center, name, method):
        """Set ``bias`` to ``center`` and print median-distance diagnostics.

        ``activations`` is a CPU tensor of stored activations; the printed values
        are the median L2 distance from them to the old and new bias.
        """
        previous = bias.clone().cpu()
        previous_distances = torch.norm(activations - previous, dim=-1)
        distances = torch.norm(activations - center, dim=-1)

        print(f"Reinitializing {name} with {method} of activations")
        print(
            f"Previous distances: {previous_distances.median(0).values.mean().item()}"
        )
        print(f"New distances: {distances.median(0).values.mean().item()}")

        bias.data = torch.as_tensor(center, dtype=self.dtype, device=self.device)

    @torch.no_grad()
    def initialize_b_dec_with_geometric_median(self, activation_store):
        """Set b_dec (and b_dec_out) to the geometric median of stored activations.

        Raises:
            ImportError: if ``compute_geometric_median`` is not available in this
                module (its import at the top of the file is commented out).
        """
        assert self.cfg.is_transcoder == activation_store.cfg.is_transcoder

        compute_median = globals().get("compute_geometric_median")
        if compute_median is None:
            raise ImportError(
                "b_dec_init_method='geometric_median' requires compute_geometric_median, "
                "which is not imported in this module; use 'mean' or 'zeros' instead."
            )

        def geometric_median(acts):
            return compute_median(
                acts, skip_typechecks=True, maxiter=100, per_component=False
            ).median

        all_activations = activation_store.storage_buffer.detach().cpu()
        self._reinitialize_bias(
            self.b_dec,
            all_activations,
            geometric_median(all_activations),
            "b_dec",
            "geometric median",
        )

        if self.b_dec_out is not None:
            all_activations_out = activation_store.storage_buffer_out.detach().cpu()
            self._reinitialize_bias(
                self.b_dec_out,
                all_activations_out,
                geometric_median(all_activations_out),
                "b_dec_out",
                "geometric median",
            )

    @torch.no_grad()
    def initialize_b_dec_with_mean(self, activation_store):
        """Set b_dec (and b_dec_out) to the mean of the stored activations."""
        assert self.cfg.is_transcoder == activation_store.cfg.is_transcoder

        all_activations = activation_store.storage_buffer.detach().cpu()
        self._reinitialize_bias(
            self.b_dec, all_activations, all_activations.mean(dim=0), "b_dec", "mean"
        )

        if self.b_dec_out is not None:
            all_activations_out = activation_store.storage_buffer_out.detach().cpu()
            self._reinitialize_bias(
                self.b_dec_out,
                all_activations_out,
                all_activations_out.mean(dim=0),
                "b_dec_out",
                "mean",
            )

    @torch.no_grad()
    def _reset_adam_state_for_features(self, optimizer, features):
        """Zero Adam's ``exp_avg``/``exp_avg_sq`` entries for the given features.

        ``features`` indexes the d_sae axis (a boolean mask or an index tensor).
        Relies on ``optimizer.state`` being ordered W_enc, b_enc, W_dec, b_dec
        (, b_dec_out) -- i.e. the parameter registration order in ``__init__``.
        """
        for dict_idx, (param, state) in enumerate(optimizer.state.items()):
            for v_key in ["exp_avg", "exp_avg_sq"]:
                if dict_idx == 0:
                    assert param.data.shape == (self.d_in, self.d_sae)
                    state[v_key][:, features] = 0.0
                elif dict_idx == 1:
                    assert param.data.shape == (self.d_sae,)
                    state[v_key][features] = 0.0
                elif dict_idx == 2:
                    assert param.data.shape == (self.d_sae, self.d_out)
                    state[v_key][features, :] = 0.0
                elif dict_idx == 3:
                    # b_dec has size d_in (it is subtracted from the input).
                    assert param.data.shape == (self.d_in,)
                elif not self.cfg.is_transcoder:
                    # Transcoders additionally have b_dec_out; anything else is unexpected.
                    raise ValueError(f"Unexpected dict_idx {dict_idx}")

        # Check that the W_enc moments were really cleared.
        for dict_idx, (param, state) in enumerate(optimizer.state.items()):
            if dict_idx != 0:
                continue
            for v_key in ["exp_avg", "exp_avg_sq"]:
                if param.data.shape != (self.d_in, self.d_sae):
                    print(
                        "Warning: it does not seem as if resetting the Adam parameters worked, there are shapes mismatches"
                    )
                selected = state[v_key][:, features]
                if selected.numel() > 0 and selected.abs().max().item() > 1e-6:
                    print(
                        "Warning: it does not seem as if resetting the Adam parameters worked"
                    )

    @torch.no_grad()
    def resample_neurons_l2(
        self,
        x: Float[Tensor, "batch_size n_hidden"],
        feature_sparsity: Float[Tensor, "n_hidden_ae"],
        optimizer: torch.optim.Optimizer,
    ) -> int:
        """Resample features whose firing frequency is below ``cfg.dead_feature_threshold``.

        Each dead feature's decoder row is replaced with a unit-norm copy of
        ``x - b_dec`` for an input drawn with probability proportional to its
        squared reconstruction L2 loss. Its encoder column is set to that
        direction scaled by the mean alive encoder norm times
        ``cfg.feature_reinit_scale``, its encoder bias is zeroed, and the
        corresponding Adam moments are reset.

        Returns:
            The number of features resampled (0 if none were dead or the batch
            is already reconstructed perfectly).
        """
        feature_reinit_scale = self.cfg.feature_reinit_scale

        # forward returns six values; only the reconstruction is needed here.
        sae_out = self.forward(x)[0]
        per_token_l2_loss = (sae_out - x).pow(2).sum(dim=-1).squeeze()

        # Find the dead neurons in this instance. If all neurons are alive, continue
        is_dead = feature_sparsity < self.cfg.dead_feature_threshold
        dead_neurons = torch.nonzero(is_dead).squeeze(-1)
        alive_neurons = torch.nonzero(~is_dead).squeeze(-1)
        n_dead = dead_neurons.numel()

        if n_dead == 0:
            return 0  # If there are no dead neurons, we don't need to resample neurons

        # TODO: Check whether we need to go through more batches as features get sparse to find high l2 loss examples.
        if per_token_l2_loss.max() < 1e-6:
            return 0  # If we have zero reconstruction loss, we don't need to resample neurons

        # Draw n_dead batch indices with probability proportional to l2_loss squared.
        per_token_l2_loss = per_token_l2_loss.to(torch.float32)  # won't work with bfloat16
        squared = per_token_l2_loss.pow(2)
        distn = Categorical(probs=squared / squared.sum())
        replacement_indices = distn.sample((n_dead,))  # shape [n_dead]

        # Index into the batch of hidden activations to get our replacement values
        replacement_values = (x - self.b_dec)[replacement_indices]  # [n_dead, n_input_ae]
        replacement_values = replacement_values / (
            replacement_values.norm(dim=1, keepdim=True) + 1e-8
        )

        # Set new decoder weights
        self.W_dec.data[is_dead, :] = replacement_values

        # Get the norm of alive neurons (or 1.0 if there are no alive neurons)
        W_enc_norm_alive_mean = (
            1.0
            if len(alive_neurons) == 0
            else self.W_enc[:, alive_neurons].norm(dim=0).mean().item()
        )

        # Lastly, set the new weights & biases
        self.W_enc.data[:, is_dead] = (
            replacement_values * W_enc_norm_alive_mean * feature_reinit_scale
        ).T
        self.b_enc.data[is_dead] = 0.0

        self._reset_adam_state_for_features(optimizer, is_dead)

        return n_dead

    @torch.no_grad()
    def resample_neurons_anthropic(
        self, dead_neuron_indices, model, optimizer, activation_store
    ):
        """
        Arthur's version of Anthropic's feature resampling procedure.

        Samples input activations in proportion to the loss increase caused by
        splicing in the SAE, uses them (unit-normalised) as the new decoder rows
        and encoder columns of the dead features, rescales the encoder columns,
        resets their encoder biases, and clears their Adam moments.
        """
        # collect global loss increases, and input activations
        global_loss_increases, global_input_activations = (
            self.collect_anthropic_resampling_losses(model, activation_store)
        )

        # sample according to losses
        probs = global_loss_increases / global_loss_increases.sum()
        sample_indices = torch.multinomial(
            probs,
            min(len(dead_neuron_indices), probs.shape[0]),
            replacement=False,
        )
        # if we don't have enough samples for all the dead neurons, take the first n
        if sample_indices.shape[0] < len(dead_neuron_indices):
            dead_neuron_indices = dead_neuron_indices[: sample_indices.shape[0]]

        # Replace W_dec with normalized differences in activations
        sampled = global_input_activations[sample_indices]
        sampled = sampled / torch.norm(sampled, dim=1, keepdim=True)
        self.W_dec.data[dead_neuron_indices, :] = sampled.to(self.dtype).to(self.device)

        # Lastly, set the new weights & biases
        self.W_enc.data[:, dead_neuron_indices] = self.W_dec.data[
            dead_neuron_indices, :
        ].T
        self.b_enc.data[dead_neuron_indices] = 0.0

        # Reset the Encoder Weights
        if dead_neuron_indices.shape[0] < self.d_sae:
            # The resampled columns were just set to unit norm, so subtracting
            # their count leaves the total norm of the alive columns.
            sum_of_all_norms = torch.norm(self.W_enc.data, dim=0).sum()
            sum_of_all_norms -= len(dead_neuron_indices)
            average_norm = sum_of_all_norms / (self.d_sae - len(dead_neuron_indices))
            self.W_enc.data[:, dead_neuron_indices] *= (
                self.cfg.feature_reinit_scale * average_norm
            )

            # Set biases to resampled value
            relevant_biases = self.b_enc.data[dead_neuron_indices].mean()
            self.b_enc.data[dead_neuron_indices] = (
                relevant_biases * 0
            )  # bias resample factor (put in config?)

        else:
            self.W_enc.data[:, dead_neuron_indices] *= self.cfg.feature_reinit_scale
            self.b_enc.data[dead_neuron_indices] = -5.0

        # TODO: Refactor this resetting to be outside of resampling.
        self._reset_adam_state_for_features(optimizer, dead_neuron_indices)

        return

    @torch.no_grad()
    def collect_anthropic_resampling_losses(self, model, activation_store):
        """
        Collects the losses for resampling neurons (anthropic).

        For ``cfg.resample_batches`` batches, compares the per-token loss with
        and without the SAE spliced in, samples one token per sequence in
        proportion to the positive loss increase, and records that token's loss
        increase and its (un-reconstructed) activation at ``cfg.hook_point``.

        Returns:
            ``(global_loss_increases, global_input_activations)`` with shapes
            ``(resample_batches * store_batch_size,)`` and ``(..., d_in)``.
        """
        batch_size = self.cfg.store_batch_size

        # we're going to collect this many forward passes (one sampled token each)
        number_final_activations = self.cfg.resample_batches * batch_size
        anthropic_iterator = tqdm(
            range(0, number_final_activations, batch_size),
            desc="Collecting losses for resampling...",
        )

        global_loss_increases = torch.zeros(
            (number_final_activations,), dtype=self.dtype, device=self.device
        )
        global_input_activations = torch.zeros(
            (number_final_activations, self.d_in), dtype=self.dtype, device=self.device
        )

        for refill_idx in anthropic_iterator:
            # get a batch, calculate per-token loss with/without using SAE reconstruction.
            batch_tokens = activation_store.get_batch_tokens()
            ce_loss_with_recons = self.get_test_loss(
                batch_tokens, model, loss_per_token=True
            )
            ce_loss_without_recons, normal_activations_cache = model.run_with_cache(
                batch_tokens,
                names_filter=self.cfg.hook_point,
                return_type="loss",
                loss_per_token=True,
            )

            normal_activations = normal_activations_cache[self.cfg.hook_point]
            if self.cfg.hook_point_head_index is not None:
                normal_activations = normal_activations[
                    :, :, self.cfg.hook_point_head_index
                ]

            # calculate the difference in loss
            changes_in_loss = (ce_loss_with_recons - ce_loss_without_recons).cpu()

            # sample one position per sequence from the positive loss differences
            positive_changes = F.relu(changes_in_loss)
            probs = positive_changes / positive_changes.sum(dim=1, keepdim=True)
            samples = Categorical(probs).sample()

            assert samples.shape == (batch_size,), (
                f"{samples.shape=}; {self.cfg.store_batch_size=}"
            )

            end_idx = refill_idx + batch_size
            rows = torch.arange(batch_size)
            global_loss_increases[refill_idx:end_idx] = changes_in_loss[rows, samples]
            global_input_activations[refill_idx:end_idx] = normal_activations[
                rows, samples
            ]

        return global_loss_increases, global_input_activations

    @torch.no_grad()
    def get_test_loss(self, batch_tokens, model, loss_per_token=False):
        """
        Run ``model`` on ``batch_tokens`` with this SAE's reconstruction spliced in
        and return the model's loss.

        ``loss_per_token`` is forwarded to the model's forward pass; pass True to
        get per-token losses (as the resampling code needs). The default keeps
        the previous behaviour.

        For transcoders the layer's MLP is temporarily replaced by this model and
        is always restored afterwards, even if the forward pass raises.
        """
        if not self.cfg.is_transcoder:
            head_index = self.cfg.hook_point_head_index

            def standard_replacement_hook(activations, hook):
                return self.forward(activations)[0].to(activations.dtype)

            def head_replacement_hook(activations, hook):
                new_actions = self.forward(activations[:, :, head_index])[0].to(
                    activations.dtype
                )
                activations[:, :, head_index] = new_actions
                return activations

            replacement_hook = (
                standard_replacement_hook
                if head_index is None
                else head_replacement_hook
            )

            return model.run_with_hooks(
                batch_tokens,
                return_type="loss",
                loss_per_token=loss_per_token,
                fwd_hooks=[(self.cfg.hook_point, replacement_hook)],
            )

        # TODO: currently, this only works with MLP transcoders
        assert "mlp" in self.cfg.out_hook_point

        class TranscoderWrapper(torch.nn.Module):
            def __init__(self, transcoder):
                super().__init__()
                self.transcoder = transcoder

            def forward(self, x):
                return self.transcoder(x)[0]

        block = model.blocks[self.cfg.hook_point_layer]
        old_mlp = block.mlp
        block.mlp = TranscoderWrapper(self)
        try:
            ce_loss_with_recons = model.run_with_hooks(
                batch_tokens, return_type="loss", loss_per_token=loss_per_token
            )
        finally:
            # Put the original MLP back so the model is left unmodified.
            block.mlp = old_mlp

        return ce_loss_with_recons

    @torch.no_grad()
    def set_decoder_norm_to_unit_norm(self):
        """Rescale each decoder row (one per feature) to unit L2 norm, in place."""
        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)

    @torch.no_grad()
    def remove_gradient_parallel_to_decoder_directions(self):
        """
        Project out of ``W_dec.grad`` the component parallel to each decoder row.

        Both ``W_dec`` and its gradient have shape (d_sae, d_out). The projection
        is exact only when the rows of ``W_dec`` have unit norm (see
        ``set_decoder_norm_to_unit_norm``).
        """
        parallel_component = einops.einsum(
            self.W_dec.grad,
            self.W_dec.data,
            "d_sae d_out, d_sae d_out -> d_sae",
        )

        self.W_dec.grad -= einops.einsum(
            parallel_component,
            self.W_dec.data,
            "d_sae, d_sae d_out -> d_sae d_out",
        )

    def save_model(self, path: str):
        """
        Save the model's state_dict together with the config used to train it.

        Supported extensions: ``.pt`` (torch.save), ``.pkl.gz`` (gzipped pickle)
        and ``.pkl`` (pickle) -- the same set ``load_from_pretrained`` reads.
        Missing parent directories are created.

        Raises:
            ValueError: for any other file extension.
        """
        folder = os.path.dirname(path)
        if folder:  # a bare filename has no directory to create
            os.makedirs(folder, exist_ok=True)

        state_dict = {"cfg": self.cfg, "state_dict": self.state_dict()}

        if path.endswith(".pt"):
            torch.save(state_dict, path)
        elif path.endswith(".pkl.gz"):
            with gzip.open(path, "wb") as f:
                pickle.dump(state_dict, f)
        elif path.endswith(".pkl"):
            with open(path, "wb") as f:
                pickle.dump(state_dict, f)
        else:
            raise ValueError(
                f"Unexpected file extension: {path}, supported extensions are .pt, .pkl, and .pkl.gz"
            )

        print(f"Saved model to {path}")

    @classmethod
    def load_from_pretrained(cls, path: str):
        """
        Load a model saved by ``save_model`` (``.pt``, ``.pkl`` or ``.pkl.gz``).

        This can be called directly on the class. If MPS is available, ``.pt``
        checkpoints are mapped to MPS and the config's device is updated.

        NOTE: these formats are unpickled (``weights_only=False``), which can run
        arbitrary code; only load trusted files.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
            IOError: if the file cannot be deserialised.
            ValueError: for an unsupported extension or missing keys.
        """
        if not os.path.isfile(path):
            raise FileNotFoundError(f"No file found at specified path: {path}")

        if path.endswith(".pt"):
            try:
                if torch.backends.mps.is_available():
                    state_dict = torch.load(
                        path, map_location="mps", weights_only=False
                    )
                    state_dict["cfg"].device = "mps"
                else:
                    state_dict = torch.load(path, weights_only=False)
            except Exception as e:
                raise IOError(
                    f"Error loading the state dictionary from .pt file: {e}"
                ) from e

        elif path.endswith(".pkl.gz"):
            try:
                with gzip.open(path, "rb") as f:
                    state_dict = pickle.load(f)
            except Exception as e:
                raise IOError(
                    f"Error loading the state dictionary from .pkl.gz file: {e}"
                ) from e
        elif path.endswith(".pkl"):
            try:
                with open(path, "rb") as f:
                    state_dict = pickle.load(f)
            except Exception as e:
                raise IOError(
                    f"Error loading the state dictionary from .pkl file: {e}"
                ) from e
        else:
            raise ValueError(
                f"Unexpected file extension: {path}, supported extensions are .pt, .pkl, and .pkl.gz"
            )

        # Ensure the loaded state contains both 'cfg' and 'state_dict'
        if "cfg" not in state_dict or "state_dict" not in state_dict:
            raise ValueError(
                "The loaded state dictionary must contain 'cfg' and 'state_dict' keys"
            )

        # Create an instance of the class using the loaded configuration
        instance = cls(cfg=state_dict["cfg"])
        instance.load_state_dict(state_dict["state_dict"])

        return instance

    def get_name(self):
        """Return an identifier built from the model name, hook point and d_sae."""
        sae_name = f"sparse_autoencoder_{self.cfg.model_name}_{self.cfg.hook_point}_{self.cfg.d_sae}"
        return sae_name
