import json
import math
import os
import sys
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Optional, Type, TypeVar, Union

import torch as t
from einops import einsum, rearrange
from jaxtyping import Float
from loguru import logger
from torch import nn
from torch.nn import functional as F

from auto_encoder.config import AutoEncoderConfig
from auto_encoder.config_enums import AutoEncoderType
from auto_encoder.helpers.sinkhorn_balancing import sinkhorn_balancing_dual
from auto_encoder.helpers.zipf_m import MCounts, calculate_m, get_m_counts

if t.cuda.is_available():
    from auto_encoder.models.model_helpers.kernels import TritonDecoderAutograd

from auto_encoder import device

if sys.version_info >= (3, 11):
    from typing import Self
else:
    # Define 'Self' as a TypeVar for older versions
    Self = TypeVar("Self", bound="AutoEncoderBase")


@dataclass
class TopKOutput:
    """Bundle returned by ``AutoEncoderBase.topk_maybe_stochastic``."""

    # Feature activations after the top-k selection has been applied
    # (batch, seq_len, num_features).
    features_BSF: t.Tensor
    # Values at the selected positions, with batch and seq_len flattened
    # into a single leading axis: (batch * seq_len, k).
    values_BsK: t.Tensor
    # Selected feature indices, same layout as ``values_BsK``.
    indices_BsK: t.Tensor


@dataclass
class TopMActOutput:
    """Bundle returned by ``AutoEncoderBase.top_m_maybe_stochastic``."""

    # Feature activations after the top-m selection has been applied
    # (batch, seq_len, num_features).
    features_BSF: t.Tensor
    # Selected values. A single tensor for the "uniform" flavour; a list with
    # one tensor per m_counts section for the "zipf" flavour.
    values_MF: Union[t.Tensor, list[t.Tensor]]
    # Selected token indices, same single-tensor / list layout as ``values_MF``.
    indices_MF: Union[t.Tensor, list[t.Tensor]]


class FeatureActivityQueue:
    """Bounded first-in-first-out store of per-feature activity tensors.

    Once ``max_length`` entries are held, appending a new entry evicts the
    oldest one first.
    """

    def __init__(self, max_length: int):
        self.max_length = max_length
        self.queue: list[t.Tensor] = []

    def __len__(self) -> int:
        """Number of entries currently stored."""
        return len(self.queue)

    def queue_full(self) -> bool:
        """Return True when exactly ``max_length`` entries are stored."""
        return self.max_length == len(self.queue)

    def append(self, current_feature_activity_F: t.Tensor) -> None:
        """Store a new entry, dropping the oldest one if the queue is full."""
        if self.queue_full():
            del self.queue[0]
        self.queue.append(current_feature_activity_F)

    def recent_activity(self) -> t.Tensor:
        """Element-wise sum of every stored entry."""
        return t.stack(self.queue, dim=0).sum(dim=0)


class AutoEncoderBase(nn.Module, ABC):
    """
    The base class for all autoencoders. This class defines the
    preprocessing, encoding, activation function, decoding and
    postprocessing steps of the autoencoder.

    The encoding, activation function and forward methods are to be
    implemented by subclasses.

    It is not recommended to overwrite the preprocess, decode or
    postprocess methods, as these are common to all autoencoders.
    """

    # encoder: nn.Module

    def __init__(
        self,
        config: AutoEncoderConfig,
        medoid_initial_tensor_N: Optional[Float[t.Tensor, "num_neurons"]],
        preprocess_scaling_factor: Optional[float],
        device: str,
    ):
        """Set up the state shared by every autoencoder subclass.

        Args:
            config: Supplies the sizes and hyperparameters read below.
            medoid_initial_tensor_N: Initial value for the neuron bias. Only
                used when it is not None and
                ``config.preprocess_strategy.median_shift`` is set; otherwise
                the bias starts at zero.
            preprocess_scaling_factor: Scaling statistic. Only used when it is
                not None and ``config.preprocess_strategy.scale_MAD`` is set;
                otherwise the scale starts at ``1 / sqrt(num_neurons)``.
            device: Stored on ``self.device``. The tensors and parameters
                created here are not explicitly placed on it.
        """
        super().__init__()
        self.config = config
        self.device = device

        # `save` writes underneath this directory.
        self.base_path = "artifacts/auto_encoder"

        # Plain tensor attribute (not a registered buffer or parameter).
        self.num_batches_inactive_F = t.zeros(config.num_features)
        self.feature_activation_queue = FeatureActivityQueue(
            config.feature_activity_queue_length
        )

        ## PREPROCESSING

        # A single learnable bias: `preprocess` subtracts it from the input and
        # `postprocess` adds it back onto the reconstruction.
        if medoid_initial_tensor_N is not None and config.preprocess_strategy.median_shift:
            self.neuron_bias_N = nn.Parameter(medoid_initial_tensor_N)
        else:
            self.neuron_bias_N = nn.Parameter(t.zeros(config.num_neurons))

        if preprocess_scaling_factor is not None and config.preprocess_strategy.scale_MAD:
            PROBABLE_ERROR = 0.6745  # https://en.wikipedia.org/wiki/Probable_error
            # For a normal distribution MAD ~= 0.6745 * sigma, i.e.
            # sigma ~= 1.4826 * MAD = MAD / PROBABLE_ERROR. Dividing by
            # PROBABLE_ERROR therefore converts a MAD into a standard deviation.
            # NOTE(review): assumes preprocess_scaling_factor is a MAD — confirm with caller.
            scale = t.tensor(preprocess_scaling_factor) / PROBABLE_ERROR
            self.scale = nn.Parameter(scale)
        else:
            scale = t.tensor(1.0) / math.sqrt(config.num_neurons)
            self.scale = nn.Parameter(scale)

        self.preprocess_strategy = config.preprocess_strategy

        ## ENCODER
        # Filled in by subclass

        ## ACTIVATION FUNCTION
        # Filled in by subclass

        ### TOP-K / TOP-M

        self.stochastic_topk_temperature = config.stochastic_topk_temperature
        self.gumbel_weighting = config.topk_gumbel_weighting
        self.topk_cutoff_prob = config.topk_cutoff_prob
        self.topm_cutoff_prob = config.topm_cutoff_prob
        self.use_straight_through_estimator = config.use_straight_through_estimator
        self.topm_flavour = config.topm_flavour

        # The capacities are computed for every top-m flavour (the zipf-only
        # guard that used to sit here is disabled); `aux_k_losses` reads them too.
        m_F = calculate_m(
            m=config.topm,
            num_features=config.num_features,
            num_tokens_per_batch=config.minibatch_size * config.seq_len,
        )
        self.feature_capacities_F = m_F

        # `m_counts` only exists as an attribute for the zipf flavour.
        if self.topm_flavour.value == "zipf":
            self.m_counts = get_m_counts(m_F)
            logger.info(f"Using Zipf top-m with m_counts: {self.m_counts}")

        self.topk = config.topk
        self.topm = config.topm

        self.use_sinkhorn_assignment = config.use_sinkhorn_assignment

        ## DECODER

        # The decoder weight has shape (num_neurons, num_features); each column
        # (one per feature) is initialised as a unit vector.
        self.decoder = nn.Linear(config.num_features, config.num_neurons, bias=False)

        dec_weight = t.randn_like(self.decoder.weight)
        dec_weight = dec_weight / dec_weight.norm(dim=0, keepdim=True)
        # Make the transposed view the contiguous one in memory (for the decoder kernel)
        dec_weight = dec_weight.T.contiguous().T  # Ensure contiguous memory layout for kernel
        self.decoder.weight = nn.Parameter(dec_weight)

        self.baseline_nfm: Optional[t.Tensor] = None

        ## POSTPROCESSING

        # Optional extra bias that `postprocess` adds to the reconstruction.
        self.postprocess_bias_N: Optional[nn.Parameter] = None
        if config.postprocess_bias:
            self.postprocess_bias_N = nn.Parameter(t.zeros(config.num_neurons))  # type: ignore

    def preprocess(
        self, x_BSN: Float[t.Tensor, "batch seq_len num_neurons"]
    ) -> tuple[Float[t.Tensor, "batch seq_len num_neurons"], Float[t.Tensor, "batch seq_len"]]:
        _batch_size, _seq_len, neuron_dim = x_BSN.shape

        normalising_constants_BS = t.norm(x_BSN, p=2, dim=-1)  # [batch_size, seq_len]
        normalising_constants_BS = normalising_constants_BS / math.sqrt(neuron_dim)

        if self.preprocess_strategy.normalise:
            # Replace any zeros with 1s to avoid division by zero
            normalising_constants_BS = t.where(
                normalising_constants_BS == 0.0, t.tensor(1.0), normalising_constants_BS
            )

            normalised_mlp_neuron_activations_BSN = (
                x_BSN - self.neuron_bias_N
            ) / normalising_constants_BS.unsqueeze(
                -1
            )  # Normalise activation vectors to have norm sqrt(num_neurons)
            return normalised_mlp_neuron_activations_BSN, normalising_constants_BS

        elif self.preprocess_strategy.scale:
            return (x_BSN - self.neuron_bias_N) / self.scale, normalising_constants_BS

        elif self.preprocess_strategy.shift:
            return x_BSN - self.neuron_bias_N, normalising_constants_BS

        else:
            return x_BSN, normalising_constants_BS

    # Subclass hook: turn the preprocessed activations (batch, seq_len,
    # num_neurons) into pre-activation feature values. The return type is left
    # open (Any) for subclasses to define.
    @abstractmethod
    def encode(self, x_center_BSN: Float[t.Tensor, "batch seq_len num_neurons"]) -> Any: ...

    # Subclass hook: map pre-activation feature values to feature activations
    # of the same (batch, seq_len, num_features) shape.
    @abstractmethod
    def activation_fn(
        self, pre_activation_features_BSF: Float[t.Tensor, "batch seq_len num_features"]
    ) -> Float[t.Tensor, "batch seq_len num_features"]: ...

    def l1_loss(self, x: t.Tensor) -> t.Tensor:
        return t.mean(t.abs(x))

    ### TOPK

    def topk_maybe_stochastic(
        self,
        features_BSF: t.Tensor,
        ones_as_values: bool = False,
        topk: Optional[int] = None,
    ) -> TopKOutput:
        """Token-choice top-k: every token keeps ``k`` of its features.

        While training with a positive ``stochastic_topk_temperature`` the
        selection is delegated to ``stochastic_topk``; otherwise the ``k``
        largest features per token are taken.

        Args:
            features_BSF: Feature values, shape (batch, seq_len, num_features).
            ones_as_values: When True the kept entries are set to 1 instead of
                retaining their feature value.
            topk: Overrides ``self.topk`` when given.

        Returns:
            TopKOutput with the masked features plus the selected values and
            indices flattened to (batch * seq_len, k).
        """
        k = self.topk if topk is None else topk

        sample_selection = self.training and self.stochastic_topk_temperature > 0
        if sample_selection:
            picked_values_BSK, picked_indices_BSK = self.stochastic_topk(features_BSF, k)
        else:
            picked_values_BSK, picked_indices_BSK = t.topk(features_BSF, k, dim=-1)

        flatten_tokens = "batch seq_len k -> (batch seq_len) k"
        picked_values_BsK = rearrange(picked_values_BSK, flatten_tokens)
        picked_indices_BsK = rearrange(picked_indices_BSK, flatten_tokens)

        # 1.0 at every selected (token, feature) position, 0.0 elsewhere.
        selection_BSF = t.zeros_like(features_BSF)
        selection_BSF.scatter_(-1, picked_indices_BSK, 1.0)

        # Binary SAEs want the kept entries to be exactly 1.
        kept_BSF = selection_BSF if ones_as_values else features_BSF * selection_BSF

        if self.use_straight_through_estimator:
            # Forward value equals kept_BSF; the backward pass sees the identity.
            result_BSF = features_BSF + (kept_BSF - features_BSF).detach()
        else:
            result_BSF = kept_BSF

        return TopKOutput(
            features_BSF=result_BSF,
            values_BsK=picked_values_BsK,
            indices_BsK=picked_indices_BsK,
        )

    def stochastic_topk(self, features_BSF: t.Tensor, topk: int) -> tuple[t.Tensor, t.Tensor]:
        """
        Applies the stochastic top-k operation to the input features.

        self.gumbel_weighting of 0 corresponds to a straight-through estimator.

        Args:
            features_BSF (torch.Tensor): Input features of shape (batch_size, seq_len, num_features).
            k (int): Number of indices to sample.

        Returns:
            torch.Tensor: Sampled values of shape (batch_size, seq_len, k).
            torch.Tensor: Sampled indices of shape (batch_size, seq_len, k).
        """
        batch_size, seq_len, num_features = features_BSF.shape

        topk = topk if topk is not None else self.topk

        # Add Gumbel noise to the features
        gumbels_BSF = -t.log(-t.log(t.rand_like(features_BSF) + 1e-12) + 1e-12)
        features_BSF = features_BSF + (gumbels_BSF * self.gumbel_weighting)

        # Convert features to probabilities
        probs_BSF = F.softmax((features_BSF) / self.stochastic_topk_temperature, dim=-1)

        # Don't pick any feature that has prob < 1 / num_features * 1e4
        probs_BSF = probs_BSF * (probs_BSF > self.topk_cutoff_prob).float()

        probs_BsF = rearrange(
            probs_BSF, "batch seq_len num_features -> (batch seq_len) num_features"
        )

        # Sample k indices without replacement
        try:
            sampled_indices_BsK = t.multinomial(probs_BsF, topk, replacement=False)
            sampled_indices_BSK = rearrange(
                sampled_indices_BsK, "(batch seq_len) k -> batch seq_len k", batch=batch_size
            )
        except RuntimeError:
            # If there aren't enough features which have a probability above the threshold then just pick the top k
            # logger.warning("Not enough features above the threshold, picking top k")
            sampled_indices_BSK = t.topk(features_BSF, topk, dim=-1).indices

        # Get the values of sampled indices
        sampled_values_BSK = t.gather(features_BSF, -1, sampled_indices_BSK)

        return sampled_values_BSK, sampled_indices_BSK

    ### TOP_M

    def top_m_maybe_stochastic(
        self,
        features_BsF: t.Tensor,
        ones_as_values: bool = False,
    ) -> TopMActOutput:
        """Feature-choice top-m: every feature keeps a number of tokens.

        The flavour is taken from ``self.topm_flavour``:

        * ``"uniform"`` — every feature keeps ``self.topm`` tokens.
        * ``"zipf"`` — features are split into the sections described by
          ``self.m_counts`` and each section keeps its own number of tokens.
          With ``self.use_sinkhorn_assignment`` the selection is made on the
          output of ``sinkhorn_balancing_dual`` instead of the raw features.

        While training with a positive ``stochastic_topk_temperature`` the
        non-sinkhorn paths sample tokens instead of taking the largest ones.

        Args:
            features_BsF: Feature values. Despite the name, the tensor must
                arrive with shape (batch_size, seq_len, num_features); it is
                flattened to (batch_size * seq_len, num_features) internally.
            ones_as_values: If True, kept entries are set to 1 instead of
                retaining their feature value.

        Returns:
            TopMActOutput holding the masked features reshaped back to
            (batch_size, seq_len, num_features) together with the selected
            values and indices (tensors for "uniform", lists of per-section
            tensors for "zipf").

        Raises:
            ValueError: If the flavour is neither "uniform" with ``self.topm``
                set nor "zipf" with ``self.m_counts`` set.
        """

        batch_size, seq_len, num_features = features_BsF.shape

        def get_feature_mask(
            features_indices_list_Int_MF: list[t.Tensor], m_counts: list[MCounts]
        ):
            """Build one 0/1 mask per section and join them along the feature axis."""
            mask_BsF_list = []

            for m_count, features_indices_section_MF in zip(
                m_counts, features_indices_list_Int_MF, strict=True
            ):
                # Allocate alongside the features so that device and dtype
                # always match them, as `zeros_like` does in the uniform branch.
                mask_section_BsF = features_BsF.new_zeros(
                    (batch_size * seq_len, m_count.count)
                )
                mask_section_BsF.scatter_(0, features_indices_section_MF, 1.0)

                mask_BsF_list.append(mask_section_BsF)

            return t.cat(mask_BsF_list, dim=1)

        features_BsF = rearrange(
            features_BsF, "batch seq_len num_features -> (batch seq_len) num_features"
        )

        if self.topm_flavour.value == "uniform" and self.topm is not None:

            if self.training and self.stochastic_topk_temperature > 0:
                probs_BsF = self.get_stochastic_topm_probs(features_BsF)
                features_values_MF, features_indices_Int_MF = (
                    self._stochastic_sample_uniform_topm(
                        probs_BsF=probs_BsF, features_BsF=features_BsF
                    )
                )

            else:
                features_values_MF, features_indices_Int_MF = self._uniform_top_m(
                    features_BsF, self.topm
                )

            mask_BsF = t.zeros_like(features_BsF)
            mask_BsF.scatter_(0, features_indices_Int_MF, 1.0)

        elif self.topm_flavour.value == "zipf" and self.m_counts is not None:
            m_counts = self.m_counts

            if self.use_sinkhorn_assignment:
                sinkhorn_assignments_BsF = sinkhorn_balancing_dual(
                    features_BsF, feature_capacities_F=self.feature_capacities_F
                )
                features_values_list_MF, features_indices_list_Int_MF = self._zipf_top_m(
                    sinkhorn_assignments_BsF, m_counts=m_counts  # type: ignore
                )

            elif self.training and self.stochastic_topk_temperature > 0:
                probs_BsF = self.get_stochastic_topm_probs(features_BsF)
                features_values_list_MF, features_indices_list_Int_MF = self._zipf_top_m(
                    features_BsF, m_counts=m_counts, probs_BsF=probs_BsF, is_stochastic=True  # type: ignore
                )

            else:
                features_values_list_MF, features_indices_list_Int_MF = self._zipf_top_m(
                    features_BsF, m_counts=m_counts  # type: ignore
                )

            # One mask per section, sized by that section's feature count,
            # concatenated back into a full (tokens, num_features) mask.
            mask_BsF = get_feature_mask(features_indices_list_Int_MF, m_counts)

        else:
            raise ValueError(
                "Method must be either 'uniform' with m provided or 'zipf' with m_counts provided."
            )

        # Apply masking to features

        if ones_as_values:
            # If doing binary SAEs then we want the values to be 1
            features_BsF_masked = mask_BsF
        else:
            features_BsF_masked = features_BsF * mask_BsF

        if self.use_straight_through_estimator:
            # Forward value equals the masked features; the backward pass sees
            # the identity.
            features_BsF = features_BsF + (features_BsF_masked - features_BsF).detach()

            # NOTE(review): this sanity check is evaluated eagerly on every call,
            # whatever the configured log level.
            logger.debug(t.allclose(features_BsF, features_BsF_masked))
        else:
            features_BsF = features_BsF_masked

        features_BSF = rearrange(
            features_BsF,
            "(batch seq_len) num_features -> batch seq_len num_features",
            batch=batch_size,
        )

        is_uniform = self.topm_flavour.value == "uniform"
        out = TopMActOutput(
            features_BSF=features_BSF,
            values_MF=features_values_MF if is_uniform else features_values_list_MF,
            indices_MF=features_indices_Int_MF if is_uniform else features_indices_list_Int_MF,
        )

        return out

    def get_stochastic_topm_probs(self, features_BsF: t.Tensor) -> t.Tensor:
        # Add Gumbel noise to the features
        gumbels_BsF = -t.log(-t.log(t.rand_like(features_BsF) + 1e-12) + 1e-12)
        features_BsF = features_BsF + (gumbels_BsF * self.gumbel_weighting)

        # Convert features to probabilities
        probs_BsF = F.softmax((features_BsF) / self.stochastic_topk_temperature, dim=0)

        # Don't pick any feature that has prob < 1 / num_features * 1e4
        probs_BsF = probs_BsF * (probs_BsF > self.topm_cutoff_prob).float()

        return probs_BsF

    def _stochastic_sample_uniform_topm(
        self, probs_BsF: t.Tensor, features_BsF: t.Tensor
    ) -> tuple[t.Tensor, t.Tensor]:
        """Sample ``self.topm`` tokens per feature without replacement.

        Falls back to the ``self.topm`` largest tokens per feature when the
        draw fails, and to all tokens when even that fails (e.g. a short final
        batch).

        Returns:
            ``(values_MF, indices_MF)`` — the feature values at the chosen
            tokens and the chosen token indices.
        """
        # multinomial draws along the last axis, so tokens have to come last.
        probs_FBs = rearrange(
            probs_BsF, "batch_seq_len num_features -> num_features batch_seq_len"
        )

        try:
            drawn_FM = t.multinomial(probs_FBs, self.topm, replacement=False)
            chosen_indices_MF = rearrange(drawn_FM, "num_features m -> m num_features")
        except RuntimeError:
            # Too few tokens above the threshold: take the largest ones instead.
            try:
                chosen_indices_MF = t.topk(features_BsF, self.topm, dim=0).indices
            except Exception as e:
                # The batch holds fewer than topm tokens, so keep every token.
                logger.error(f"Error: {e}")
                num_tokens, _ = features_BsF.shape
                chosen_indices_MF = t.topk(features_BsF, num_tokens, dim=0).indices

        chosen_values_MF = t.gather(features_BsF, 0, chosen_indices_MF)
        return chosen_values_MF, chosen_indices_MF

    def _uniform_top_m(self, features_BsF: t.Tensor, m: int) -> tuple[t.Tensor, t.Tensor]:
        return t.topk(features_BsF, m, dim=0)

    def _zipf_top_m(
        self,
        features_BsF: t.Tensor,
        m_counts: list[MCounts],
        probs_BsF: Optional[t.Tensor] = None,
        is_stochastic: bool = False,
    ) -> tuple[list[t.Tensor], list[t.Tensor]]:
        features_values_list: list[t.Tensor] = []  # MF where M varies
        features_indices_list: list[t.Tensor] = []  # MF

        start = 0
        end = m_counts[0].count

        for idx, m_count in enumerate(m_counts):
            m = m_count.m_value
            count = m_count.count
            # logger.info(start)
            # logger.info(end)

            # logger.info(features_BsF.shape)
            if is_stochastic and (probs_BsF is not None):
                probs_FBs = rearrange(
                    probs_BsF, "batch_seq_len num_features -> num_features batch_seq_len"
                )

                try:
                    sampled_indices_FM = t.multinomial(
                        probs_FBs[start:end, :], m, replacement=False
                    )
                    sampled_indices_MF = rearrange(
                        sampled_indices_FM, "num_features m -> m num_features"
                    )
                except RuntimeError:
                    # If there aren't enough features which have a probability above the threshold then just pick the top k
                    # logger.warning("Not enough features above the threshold, picking top k")
                    sampled_indices_MF = t.topk(features_BsF[:, start:end], m, dim=0).indices

                sampled_values_MF = t.gather(features_BsF[:, start:end], 0, sampled_indices_MF)

                features_values_MF = sampled_values_MF
                features_indices_Int_MF = sampled_indices_MF

            elif not is_stochastic:
                features_values_MF, features_indices_Int_MF = t.topk(
                    features_BsF[:, start:end], m, dim=0
                )
            else:
                raise ValueError("To use stochastic sampling, probs_BsF must be provided.")

            features_values_list.append(features_values_MF)
            features_indices_list.append(features_indices_Int_MF)

            start += count
            end += m_counts[idx + 1].count if idx + 1 < len(m_counts) else 0

        return features_values_list, features_indices_list

    def aux_k_losses(
        self, x_BSN: t.Tensor, x_hat_BSN: t.Tensor, pre_activation_features_BSF: t.Tensor
    ) -> tuple[t.Tensor, t.Tensor]:
        if len(self.feature_activation_queue) < self.feature_activation_queue.max_length:
            zero_tensor = t.tensor(0, dtype=x_BSN.dtype, device=x_BSN.device)
            return zero_tensor, zero_tensor.clone()

        capacity_75_percentile = t.quantile(self.feature_capacities_F.float(), 0.25)

        feature_activation_queue_len = len(self.feature_activation_queue)

        feature_activity_F = self.feature_activation_queue.recent_activity()

        _, ordering = t.sort(feature_activity_F, descending=True)

        dead_features_F = feature_activity_F == 0
        num_dead_features = int(dead_features_F.sum().item())

        self.feature_capacities_F = self.feature_capacities_F.to(feature_activity_F.device)

        ordered_feature_capacities_F = self.feature_capacities_F[ordering]

        dying_features_F = (
            feature_activity_F
            < (ordered_feature_capacities_F * 0.6 * feature_activation_queue_len)
        ) & (feature_activity_F < capacity_75_percentile)

        num_dying_features = int(dying_features_F.sum().item())

        residual_BSN = x_hat_BSN - x_BSN

        def _aux_k_loss(inactive_features: t.Tensor, num_inactive_features: int) -> t.Tensor:
            topk = min(self.topk, num_inactive_features)

            auxiliary_topk_out = t.topk(
                pre_activation_features_BSF[:, :, inactive_features], topk, dim=-1
            )

            acts_aux_BSF = t.zeros_like(
                pre_activation_features_BSF[:, :, inactive_features]
            ).scatter(-1, auxiliary_topk_out.indices, auxiliary_topk_out.values)

            decoder_weights_NF = self.decoder.weight[:, inactive_features]
            x_hat_dead_aux_BSN = acts_aux_BSF @ decoder_weights_NF.T

            aux_k_reconstruction_loss = t.mean((residual_BSN - x_hat_dead_aux_BSN).pow(2))

            return aux_k_reconstruction_loss

        if dead_features_F.sum() > 0:
            aux_dead_reconstruction_loss = _aux_k_loss(dead_features_F, num_dead_features)
        else:
            aux_dead_reconstruction_loss = t.tensor(0, dtype=x_BSN.dtype, device=x_BSN.device)

        if dying_features_F.sum() > 0:
            aux_dying_reconstruction_loss = _aux_k_loss(dying_features_F, num_dying_features)
        else:
            aux_dying_reconstruction_loss = t.tensor(0, dtype=x_BSN.dtype, device=x_BSN.device)

        # If there are nans then return 0
        if t.isnan(aux_dead_reconstruction_loss):
            aux_dead_reconstruction_loss = t.tensor(0, dtype=x_BSN.dtype, device=x_BSN.device)
            logger.warning("aux_dead_reconstruction_loss is nan")
        if t.isnan(aux_dying_reconstruction_loss):
            aux_dying_reconstruction_loss = t.tensor(0, dtype=x_BSN.dtype, device=x_BSN.device)
            logger.warning("aux_dying_reconstruction_loss is nan")

        return aux_dead_reconstruction_loss, aux_dying_reconstruction_loss

    def decode(
        self,
        features_BSF: Optional[t.Tensor] = None,
        indices_BsK: Optional[t.Tensor] = None,
        values_BsK: Optional[t.Tensor] = None,
    ) -> Float[t.Tensor, "batch seq_len num_neurons"]:

        assert features_BSF is not None or (
            indices_BsK is not None and values_BsK is not None
        ), "Either features or indices and vals must be provided."

        if (
            self.config.use_decoder_kernel
            and indices_BsK is not None
            and values_BsK is not None
        ):
            matmul_precision = t.get_float32_matmul_precision()

            # Set float32 matmul precision to highest for kernel
            t.set_float32_matmul_precision("highest")

            raw_x_hat_BsN: t.Tensor = TritonDecoderAutograd.apply(
                indices_BsK, values_BsK, self.decoder.weight
            )  # type: ignore

            raw_x_hat_BSN = rearrange(
                raw_x_hat_BsN,
                "(batch seq_len) num_neurons -> batch seq_len num_neurons",
                seq_len=self.config.seq_len,
            )

            # Reset float32 matmul precision for the rest of the model
            t.set_float32_matmul_precision(matmul_precision)

        elif features_BSF is not None:
            raw_x_hat_BSN = self.decoder(features_BSF)

        else:
            raise ValueError("Either features or indices and vals must be provided.")

        return raw_x_hat_BSN

    def postprocess(
        self,
        raw_x_hat_BSN: Float[t.Tensor, "batch seq_len num_neurons"],
        normalising_constants_BS: Float[t.Tensor, "batch seq_len"],
    ) -> Float[t.Tensor, "batch seq_len num_neurons"]:
        if self.preprocess_strategy.normalise:
            x_hat_BSN = (
                raw_x_hat_BSN * normalising_constants_BS.unsqueeze(-1) + self.neuron_bias_N
            )

        if self.preprocess_strategy.scale:
            x_hat_BSN = raw_x_hat_BSN * self.scale + self.neuron_bias_N
        elif self.preprocess_strategy.shift:
            x_hat_BSN = raw_x_hat_BSN + self.neuron_bias_N
        else:
            x_hat_BSN = raw_x_hat_BSN

        if self.postprocess_bias_N is not None:
            return x_hat_BSN + self.postprocess_bias_N

        return x_hat_BSN

    def nfm_losses(self, possible_baseline_reset: bool = False) -> tuple[t.Tensor, t.Tensor]:
        """Measure how strongly the decoder's feature directions overlap.

        The decoder columns are L2-normalised and their Gram matrix is formed
        (one operand is detached, so gradients flow through only one side).
        The identity is subtracted so that a feature's overlap with itself
        does not count.

        ``possible_baseline_reset`` is accepted but not used here.

        Returns
        -------
        tuple[t.Tensor, t.Tensor]
            ``(rms_overlap, mean_max_overlap)``: the square root of the summed
            squared entries divided by ``num_features * (num_features - 1)``,
            and the mean over features of each row's largest entry.
        """
        weights_NF = self.decoder.weight
        unit_weights_NF = F.normalize(weights_NF, p=2, dim=0)

        n_neurons, n_features = weights_NF.shape
        assert n_features > n_neurons

        gram_FF = einsum(
            unit_weights_NF,
            unit_weights_NF.detach(),
            "num_neurons num_features1, num_neurons num_features2 -> num_features1 num_features2",
        )
        identity_FF = t.eye(n_features).to(weights_NF)
        overlap_FF = gram_FF - identity_FF

        n_pairs = n_features * (n_features - 1)
        mean_squared_overlap = t.sum(overlap_FF**2) / n_pairs
        rms_overlap = t.sqrt(mean_squared_overlap)

        largest_overlap_F = t.max(overlap_FF, dim=1).values
        mean_max_overlap = t.mean(largest_overlap_F)

        return rms_overlap, mean_max_overlap

    def _l0_sparsity(self, feature_activations_BSF: t.Tensor):
        batch_size, seq_len, _num_features = feature_activations_BSF.shape

        num_active_features = t.count_nonzero(feature_activations_BSF)
        active_features_per_token = num_active_features / (batch_size * seq_len)
        return active_features_per_token

    def update_feature_activity_(self, features_BSF: t.Tensor) -> None:
        feature_activity_F = t.sum(features_BSF > 0, dim=(0, 1))
        self.feature_activation_queue.append(feature_activity_F)

    def save(self, path: str, other_details: str):
        """Write the config (``.json``) and the state dict (``.pt``) to disk.

        ``path`` is taken relative to ``self.base_path`` and must not carry an
        extension. Any missing parent directory is created. ``other_details``
        is stored on the config before it is serialised.

        Failures while writing are logged and not re-raised.
        """
        path = f"{self.base_path}/{path}"

        parent_dir = os.path.dirname(path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        try:
            # Config first, so the JSON carries the supplied details.
            self.config.other_details = other_details
            serialised_config = self.config.serialise_to_json()
            with open(f"{path}.json", "w") as config_file:
                config_file.write(serialised_config)

            # Then the weights.
            t.save(self.state_dict(), f"{path}.pt")

            logger.success(f"Model saved to {path}")

        except Exception as e:
            logger.error(f"Failed to save model to {path}, error: {e}")

    @classmethod
    def from_pretrained(
        cls: Type[Self],
        path: str,
        device: str = device,
    ) -> Self:
        """Rebuild a model from the ``.json`` / ``.pt`` pair written by ``save``.

        Args:
            path: File path without extension. Unlike ``save``, it is used
                as given and is not prefixed with ``base_path``.
            device: Passed to the constructor and used as the target when
                loading the checkpoint tensors.

        Returns:
            An instance of ``cls`` with the stored weights loaded.

        If the stored config cannot be parsed, a default config is built from
        the stored ``num_neurons``, ``num_features`` and ``autoencoder_type``
        only, and a warning is logged.
        """

        # Load config from file
        with open(path + ".json", "r") as f:
            loaded_json_str = f.read()
        try:
            config = AutoEncoderConfig.from_json(loaded_json_str)
        except Exception as e:
            # Fall back to the handful of fields needed to size the model
            loaded_json = json.loads(loaded_json_str)

            autoencoder_type_str = loaded_json["autoencoder_type"]["__enum__"].split(".")[-1]

            config = AutoEncoderConfig(
                num_neurons=loaded_json["num_neurons"],
                num_features=loaded_json["num_features"],
                autoencoder_type=AutoEncoderType[autoencoder_type_str],
            )

            logger.warning(e)
            logger.warning("Using default config")

        model = cls(
            config,
            medoid_initial_tensor_N=None,
            preprocess_scaling_factor=None,
            device=device,
        )

        # map_location lets a checkpoint saved on one device (e.g. CUDA) be
        # loaded on a host where that device is unavailable.
        # NOTE(review): torch.load unpickles the file — only load checkpoints
        # from trusted sources.
        state_dict = t.load(path + ".pt", map_location=device)
        model.load_state_dict(state_dict)

        return model

    @abstractmethod
    def forward(
        self,
        x: t.Tensor,
        output_intermediate_activations: bool = False,
        output_supermodel_output: bool = False,
    ) -> Any:
        """Forward method for the autoencoder. To be subclassed.

        The basic form is:

        1. x_center, normalising_constants = self.preprocess(x)
        2. pre_activation_features = self.encode(x_center)
        3. features = self.activation_fn(pre_activation_features)
        4. raw_x_hat = self.decode(features)
        5. x_hat = self.postprocess(raw_x_hat, normalising_constants)

        The goal of the autoencoder is to learn the identity function x -> x_hat.

        Parameters
        ----------
        x : t.Tensor
            Input activations; ``preprocess`` expects the shape
            (batch, seq_len, num_neurons).
        output_intermediate_activations : bool, optional
            Presumably asks the subclass to also return intermediate
            activations; the exact contents are defined by the subclass.
            By default False.
        output_supermodel_output : bool, optional
            Meaning defined by the subclass (not visible in this base class).
            By default False.

        Returns
        -------
        Any
            Defined by the subclass.
        """
        ...
