"""
This file originates from the saebench repository.
We have modified it to incorporate the correct hookpoint for FFKV/Transcoder evaluations.
Modified lines are indicated with # MODIFIED.
"""

import argparse
import gc
import logging
import math
import os
import subprocess
import time
from collections import defaultdict
from collections.abc import Mapping
from dataclasses import asdict, dataclass, field
from functools import partial
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path
from typing import Any

import einops
import torch
from sae_lens.sae import SAE
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from sae_lens.training.activations_store import ActivationsStore
from tqdm import tqdm
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookedRootModule

import ff_kv_sae.sae_bench_utils.general_utils as general_utils
import ff_kv_sae.sae_bench_utils.sae_selection_utils as sae_selection_utils
from ff_kv_sae.evals.core.eval_config import CoreEvalConfig
from ff_kv_sae.evals.core.eval_output import (
    CoreEvalOutput,
    CoreFeatureMetric,
    CoreMetricCategories,
    MiscMetrics,
    ModelBehaviorPreservationMetrics,
    ModelPerformancePreservationMetrics,
    ReconstructionQualityMetrics,
    ShrinkageMetrics,
    SparsityMetrics,
    TokenStatsMetrics,
)
from ff_kv_sae.sae_bench_utils import (
    get_eval_uuid,
    get_sae_bench_version,
    get_sae_lens_version,
)

logger = logging.getLogger(__name__)

# you can truncate to save space/bandwidth, but be warned that this will
# likely screw up the feature density metrics among others. 10 is a good
# compromise.
DEFAULT_FLOAT_PRECISION = 10


def get_library_version() -> str:
    """Return the installed ``sae_lens`` version string.

    Falls back to ``"unknown"`` when package metadata for ``sae_lens``
    cannot be found.
    """
    try:
        installed = version("sae_lens")
    except PackageNotFoundError:
        installed = "unknown"
    return installed


def get_git_hash() -> str:
    """
    Return the abbreviated hash of the currently checked-out Git commit.

    ``git rev-parse --short HEAD`` is run from the parent of the directory
    that contains this file. If git exits with a non-zero status or the
    process cannot be started, 'unknown' is returned instead.
    """
    try:
        # Run from the directory expected to hold the .git folder.
        repo_root = Path(__file__).resolve().parent.parent  # Adjust if necessary
        short_hash = subprocess.check_output(
            ["git", "rev-parse", "--short", "HEAD"],
            cwd=repo_root,
            stderr=subprocess.PIPE,
            text=True,
        )
    except (subprocess.CalledProcessError, OSError):
        # FileNotFoundError (git not installed) is a subclass of OSError.
        return "unknown"
    return short_hash.strip()


# Everything by default is false so the user can just set the ones they want to true
@dataclass
class MultipleEvalsConfig:
    """Batch counts and on/off switches for a multi-metric evaluation run.

    All ``compute_*`` switches default to ``False`` so that callers enable
    only the metrics they want. ``library_version`` and ``git_hash`` are
    populated at instantiation time by ``get_library_version`` and
    ``get_git_hash`` respectively.
    """

    # Number of prompts per evaluation batch; no value is set by default.
    batch_size_prompts: int | None = None
    # Batch counts for the two evaluation passes.
    # NOTE(review): presumably mirror the like-named CoreEvalConfig fields
    # read by run_evals — confirm against callers.
    n_eval_reconstruction_batches: int = 10
    n_eval_sparsity_variance_batches: int = 1
    # Metric switches (all off by default).
    compute_kl: bool = False
    compute_ce_loss: bool = False
    compute_l2_norms: bool = False
    compute_sparsity_metrics: bool = False
    compute_variance_metrics: bool = False
    # Provenance information, computed once per instance.
    library_version: str = field(default_factory=get_library_version)
    git_hash: str = field(default_factory=get_git_hash)


def get_multiple_evals_everything_config(
    batch_size_prompts: int | None = None,
    n_eval_reconstruction_batches: int = 10,
    n_eval_sparsity_variance_batches: int = 1,
) -> MultipleEvalsConfig:
    """
    Build a MultipleEvalsConfig in which every ``compute_*`` switch is True.

    The batch size and the two batch counts are forwarded unchanged.
    """
    every_metric_on = dict.fromkeys(
        (
            "compute_kl",
            "compute_ce_loss",
            "compute_l2_norms",
            "compute_sparsity_metrics",
            "compute_variance_metrics",
        ),
        True,
    )
    return MultipleEvalsConfig(
        batch_size_prompts=batch_size_prompts,
        n_eval_reconstruction_batches=n_eval_reconstruction_batches,
        n_eval_sparsity_variance_batches=n_eval_sparsity_variance_batches,
        **every_metric_on,
    )


@torch.no_grad()
def run_evals(
    sae: SAE,
    activation_store: ActivationsStore,
    model: HookedRootModule,
    eval_config: CoreEvalConfig = CoreEvalConfig(),
    model_kwargs: Mapping[str, Any] = {},
    ignore_tokens: set[int | None] = set(),
    verbose: bool = False,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Run the core evaluations selected in ``eval_config`` for one SAE.

    Args:
        sae: The SAE under evaluation; its ``cfg.hook_name`` selects the hook
            point used for the evaluation.
        activation_store: Supplies token batches, the context size and the
            fallback batch size (``store_batch_size_prompts``).
        model: The hooked model the SAE is attached to.
        eval_config: Flags and batch counts selecting what is computed.
        model_kwargs: Extra keyword arguments forwarded to the
            sparsity/variance pass.
        ignore_tokens: Token ids excluded from the metrics.
        verbose: Show progress bars while iterating over batches.

    Returns:
        ``(all_metrics, feature_metrics)``. ``all_metrics`` maps a metric
        group name to a dict of scalar metrics, with empty groups removed;
        ``token_stats`` is always present. ``feature_metrics`` holds the
        per-feature metrics (possibly empty).

    Raises:
        ValueError: If the configuration resulted in no metric being computed.

    The default values of ``eval_config``, ``model_kwargs`` and
    ``ignore_tokens`` are shared between calls; this function never mutates
    them.
    """
    # MODIFIED
    hook_name = sae.cfg.hook_name
    print(f"INFO: Core evaluation using hook point: {hook_name}")
    actual_batch_size = (
        eval_config.batch_size_prompts or activation_store.store_batch_size_prompts
    )

    # TODO: Come up with a cleaner long term strategy here for SAEs that do reshaping.
    # turn off hook_z reshaping mode if it's on, and restore it after evals
    uses_hook_z = "hook_z" in hook_name
    if uses_hook_z:
        previous_hook_z_reshaping_mode = sae.hook_z_reshaping_mode
        sae.turn_off_forward_pass_hook_z_reshaping()
    else:
        previous_hook_z_reshaping_mode = None

    all_metrics: dict[str, dict[str, Any]] = {
        "model_behavior_preservation": {},
        "model_performance_preservation": {},
        "reconstruction_quality": {},
        "shrinkage": {},
        "sparsity": {},
        "token_stats": {},
    }

    # The try/finally guarantees the SAE's reshaping mode is restored even
    # when one of the evaluation passes raises.
    try:
        if eval_config.compute_kl or eval_config.compute_ce_loss:
            assert eval_config.n_eval_reconstruction_batches > 0
            reconstruction_metrics = get_downstream_reconstruction_metrics(
                sae,
                model,
                activation_store,
                compute_kl=eval_config.compute_kl,
                compute_ce_loss=eval_config.compute_ce_loss,
                n_batches=eval_config.n_eval_reconstruction_batches,
                eval_batch_size_prompts=actual_batch_size,
                ignore_tokens=ignore_tokens,
                exclude_special_tokens_from_reconstruction=eval_config.exclude_special_tokens_from_reconstruction,
                verbose=verbose,
            )

            if eval_config.compute_kl:
                all_metrics["model_behavior_preservation"].update(
                    {
                        "kl_div_score": reconstruction_metrics["kl_div_score"],
                        "kl_div_with_ablation": reconstruction_metrics[
                            "kl_div_with_ablation"
                        ],
                        "kl_div_with_sae": reconstruction_metrics["kl_div_with_sae"],
                    }
                )

            if eval_config.compute_ce_loss:
                all_metrics["model_performance_preservation"].update(
                    {
                        "ce_loss_score": reconstruction_metrics["ce_loss_score"],
                        "ce_loss_with_ablation": reconstruction_metrics[
                            "ce_loss_with_ablation"
                        ],
                        "ce_loss_with_sae": reconstruction_metrics["ce_loss_with_sae"],
                        "ce_loss_without_sae": reconstruction_metrics[
                            "ce_loss_without_sae"
                        ],
                    }
                )

            activation_store.reset_input_dataset()

        if (
            eval_config.compute_l2_norms
            or eval_config.compute_sparsity_metrics
            or eval_config.compute_variance_metrics
        ):
            assert eval_config.n_eval_sparsity_variance_batches > 0
            sparsity_variance_metrics, feature_metrics = (
                get_sparsity_and_variance_metrics(
                    sae,
                    model,
                    activation_store,
                    compute_l2_norms=eval_config.compute_l2_norms,
                    compute_sparsity_metrics=eval_config.compute_sparsity_metrics,
                    compute_variance_metrics=eval_config.compute_variance_metrics,
                    compute_featurewise_density_statistics=eval_config.compute_featurewise_density_statistics,
                    n_batches=eval_config.n_eval_sparsity_variance_batches,
                    eval_batch_size_prompts=actual_batch_size,
                    model_kwargs=model_kwargs,
                    ignore_tokens=ignore_tokens,
                    verbose=verbose,
                )
            )

            if eval_config.compute_l2_norms:
                all_metrics["shrinkage"].update(
                    {
                        "l2_norm_in": sparsity_variance_metrics["l2_norm_in"],
                        "l2_norm_out": sparsity_variance_metrics["l2_norm_out"],
                        "l2_ratio": sparsity_variance_metrics["l2_ratio"],
                        "relative_reconstruction_bias": sparsity_variance_metrics[
                            "relative_reconstruction_bias"
                        ],
                    }
                )

            if eval_config.compute_sparsity_metrics:
                all_metrics["sparsity"].update(
                    {
                        "l0": sparsity_variance_metrics["l0"],
                        "l1": sparsity_variance_metrics["l1"],
                    }
                )

            if eval_config.compute_variance_metrics:
                all_metrics["reconstruction_quality"].update(
                    {
                        "explained_variance": sparsity_variance_metrics[
                            "explained_variance"
                        ],
                        "explained_variance_legacy": sparsity_variance_metrics[
                            "explained_variance_legacy"
                        ],
                        "mse": sparsity_variance_metrics["mse"],
                        "cossim": sparsity_variance_metrics["cossim"],
                    }
                )
        else:
            feature_metrics = {}

        if eval_config.compute_featurewise_weight_based_metrics:
            feature_metrics |= get_featurewise_weight_based_metrics(sae)

        # `all_metrics` is pre-seeded with every group name, so its length is
        # never zero; check whether any group or feature metric was filled.
        if not feature_metrics and not any(all_metrics.values()):
            raise ValueError(
                "No metrics were computed, please set at least one metric to True."
            )
    finally:
        # restore previous hook z reshaping mode if necessary
        if uses_hook_z:
            if previous_hook_z_reshaping_mode and not sae.hook_z_reshaping_mode:
                sae.turn_on_forward_pass_hook_z_reshaping()
            elif not previous_hook_z_reshaping_mode and sae.hook_z_reshaping_mode:
                sae.turn_off_forward_pass_hook_z_reshaping()

    total_tokens_evaluated_eval_reconstruction = (
        activation_store.context_size
        * eval_config.n_eval_reconstruction_batches
        * actual_batch_size
    )

    total_tokens_evaluated_eval_sparsity_variance = (
        activation_store.context_size
        * eval_config.n_eval_sparsity_variance_batches
        * actual_batch_size
    )

    # Token counts are derived from the configured batch counts, whether or
    # not the corresponding pass was actually enabled above.
    all_metrics["token_stats"] = {
        "total_tokens_eval_reconstruction": total_tokens_evaluated_eval_reconstruction,
        "total_tokens_eval_sparsity_variance": total_tokens_evaluated_eval_sparsity_variance,
    }

    # Remove empty metric groups
    all_metrics = {k: v for k, v in all_metrics.items() if v}

    return all_metrics, feature_metrics


def calculate_max_cosine_sim(
    encoder_DF: torch.Tensor, batch_size: int = 100
) -> torch.Tensor:
    """
    Compute, for every column, its largest cosine similarity to another column.

    encoder_DF: Tensor of shape (D, F)
                where D = dimension of each feature
                and F = number of features
    batch_size: The number of columns processed in each chunk (must be >= 1).

    Returns:
    max_sims: A tensor of shape (F,) where each entry i is the
              maximum cosine similarity of column i with any other column.
              With a single column there is no "other" column and the
              entry is -inf.

    Raises:
    ValueError: if batch_size is smaller than 1.
    """
    if batch_size < 1:
        # A negative step would skip the loop and return uninitialised memory.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # 1) Normalize columns so each feature vector has unit norm.
    enc_norm_DF = torch.nn.functional.normalize(encoder_DF, p=2, dim=0)

    F_ = enc_norm_DF.shape[1]  # Number of features

    max_sims_F = torch.empty(F_, dtype=enc_norm_DF.dtype, device=enc_norm_DF.device)

    # 2) Process columns in batches to avoid creating an F x F matrix
    for start in range(0, F_, batch_size):
        end = min(start + batch_size, F_)

        # 3) Compute cosine similarity between this chunk and ALL columns.
        sims_CF = enc_norm_DF[:, start:end].t() @ enc_norm_DF

        # 4) Ignore self-similarity: row r of the chunk is column start + r,
        #    so its self-similarity sits at (r, start + r). One indexed write
        #    replaces a per-element Python loop.
        rows_C = torch.arange(end - start, device=sims_CF.device)
        sims_CF[rows_C, rows_C + start] = float("-inf")

        # 5) Take the max over each row in the chunk and store it.
        max_sims_F[start:end] = sims_CF.max(dim=1).values

    return max_sims_F


def calculate_mean_cosine_sim(
    encoder_DF: torch.Tensor, batch_size: int = 100
) -> torch.Tensor:
    """
    Compute, for every column, its mean cosine similarity to all other columns.

    encoder_DF: Tensor of shape (D, F)
                where D = dimension of each feature
                and F = number of features
    batch_size: The number of columns processed in each chunk (must be >= 1).

    Returns:
    mean_sims: A tensor of shape (F,) where each entry i is the
              mean cosine similarity of column i with all other columns.
              With a single column there is nothing to average over and
              the entry is 0.

    Raises:
    ValueError: if batch_size is smaller than 1.
    """
    if batch_size < 1:
        # A negative step would skip the loop and return uninitialised memory.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # 1) Normalize columns so each feature vector has unit norm.
    enc_norm_DF = torch.nn.functional.normalize(encoder_DF, p=2, dim=0)

    F_ = enc_norm_DF.shape[1]  # Number of features

    if F_ <= 1:
        # No "other" column exists; the mean is defined as 0 in that case.
        return torch.zeros(F_, dtype=enc_norm_DF.dtype, device=enc_norm_DF.device)

    mean_sims_F = torch.empty(F_, dtype=enc_norm_DF.dtype, device=enc_norm_DF.device)

    # 2) Process columns in batches to avoid creating an F x F matrix
    for start in range(0, F_, batch_size):
        end = min(start + batch_size, F_)
        n_rows = end - start

        # 3) Compute cosine similarity between this chunk and ALL columns.
        sims_CF = enc_norm_DF[:, start:end].t() @ enc_norm_DF

        # 4) Drop the self-similarity entry of each row: row r of the chunk
        #    is column start + r, so it sits at (r, start + r).
        keep_CF = torch.ones_like(sims_CF, dtype=torch.bool)
        rows_C = torch.arange(n_rows, device=sims_CF.device)
        keep_CF[rows_C, rows_C + start] = False

        # 5) Boolean indexing returns the kept entries in row-major order and
        #    every row keeps exactly F - 1 of them, so they can be viewed as
        #    (rows, F - 1) and averaged per row in a single call.
        mean_sims_F[start:end] = sims_CF[keep_CF].view(n_rows, F_ - 1).mean(dim=1)

    return mean_sims_F


def get_featurewise_weight_based_metrics(sae: SAE) -> dict[str, Any]:
    """Derive per-feature statistics from the SAE weights alone.

    Returns a dict of Python lists holding the encoder bias, the encoder
    norm, the encoder/decoder cosine similarity, and the max and mean cosine
    similarity of each encoder and decoder direction with the other
    directions.
    """
    enc_weights = sae.W_enc
    dec_weights_T = sae.W_dec.T

    # Unit-normalise along dim 0 and move to the CPU.
    enc_unit = (enc_weights / enc_weights.norm(dim=0, keepdim=True)).cpu()
    dec_unit = (dec_weights_T / dec_weights_T.norm(dim=0, keepdim=True)).cpu()

    # gated models have a different bias (no b_enc)
    if hasattr(sae, "b_enc") or hasattr(sae, "b_mag"):
        if sae.cfg.architecture != "gated":
            bias_tensor = sae.b_enc
        else:
            bias_tensor = sae.b_mag
    else:
        # Neither bias attribute exists: report zeros for every feature.
        bias_tensor = torch.zeros(sae.cfg.d_sae)

    enc_dec_similarity = torch.nn.functional.cosine_similarity(
        dec_unit.T,
        enc_unit.T,
    )

    return {
        "encoder_bias": bias_tensor.cpu().tolist(),
        "encoder_norm": enc_weights.norm(dim=-2).cpu().tolist(),
        "encoder_decoder_cosine_sim": enc_dec_similarity.cpu().tolist(),
        "max_encoder_cosine_sim": calculate_max_cosine_sim(enc_weights)
        .cpu()
        .tolist(),
        "max_decoder_cosine_sim": calculate_max_cosine_sim(dec_weights_T)
        .cpu()
        .tolist(),
        "mean_encoder_cosine_sim": calculate_mean_cosine_sim(enc_weights)
        .cpu()
        .tolist(),
        "mean_decoder_cosine_sim": calculate_mean_cosine_sim(dec_weights_T)
        .cpu()
        .tolist(),
    }


def get_downstream_reconstruction_metrics(
    sae: SAE,
    model: HookedRootModule,
    activation_store: ActivationsStore,
    compute_kl: bool,
    compute_ce_loss: bool,
    n_batches: int,
    eval_batch_size_prompts: int,
    ignore_tokens: set[int | None] = set(),
    exclude_special_tokens_from_reconstruction: bool = False,
    verbose: bool = False,
):
    """Average the downstream KL / CE-loss metrics over ``n_batches`` batches.

    For each batch the per-token metrics returned by ``get_recons_loss`` are
    collected, positions holding a token from ``ignore_tokens`` are dropped,
    and everything is averaged at the end.

    Args:
        sae: The SAE under evaluation.
        model: The hooked model the SAE is attached to.
        activation_store: Source of token batches.
        compute_kl: Collect ``kl_div_with_sae`` / ``kl_div_with_ablation`` and
            derive ``kl_div_score``.
        compute_ce_loss: Collect the three ``ce_loss_*`` metrics and derive
            ``ce_loss_score``.
        n_batches: Number of token batches to evaluate.
        eval_batch_size_prompts: Number of prompts requested per batch.
        ignore_tokens: Token ids whose positions are excluded from the means.
            The shared default set is never mutated here.
        exclude_special_tokens_from_reconstruction: Forwarded to
            ``get_recons_loss``.
        verbose: Show a progress bar over the batches.

    Returns:
        A dict mapping metric names to Python floats.
    """
    metrics_dict = {}
    if compute_kl:
        metrics_dict["kl_div_with_sae"] = []
        metrics_dict["kl_div_with_ablation"] = []
    if compute_ce_loss:
        metrics_dict["ce_loss_with_sae"] = []
        metrics_dict["ce_loss_without_sae"] = []
        metrics_dict["ce_loss_with_ablation"] = []

    batch_iter = range(n_batches)
    if verbose:
        batch_iter = tqdm(batch_iter, desc="Reconstruction Batches")

    for _ in batch_iter:
        batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)
        recons_metrics = get_recons_loss(
            sae,
            model,
            batch_tokens,
            activation_store,
            compute_kl=compute_kl,
            compute_ce_loss=compute_ce_loss,
            ignore_tokens=ignore_tokens,
            exclude_special_tokens_from_reconstruction=exclude_special_tokens_from_reconstruction,
        )

        # The keep-mask depends only on the tokens of this batch, so build it
        # once here rather than once per metric.
        keep_mask = None
        if len(ignore_tokens) > 0:
            keep_mask = torch.logical_not(
                torch.any(
                    torch.stack(
                        [batch_tokens == token for token in ignore_tokens], dim=0
                    ),
                    dim=0,
                )
            )

        for metric_name, metric_value in recons_metrics.items():
            if keep_mask is not None:
                metric_mask = keep_mask
                if metric_value.shape[1] != metric_mask.shape[1]:
                    # ce loss will be missing the last value
                    metric_mask = metric_mask[:, :-1]
                metric_value = metric_value[metric_mask]

            metrics_dict[metric_name].append(metric_value)

    metrics: dict[str, float] = {}
    for metric_name, metric_values in metrics_dict.items():
        metrics[metric_name] = torch.cat(metric_values).mean().item()

    # Scores express how much of the damage caused by ablation the SAE
    # reconstruction recovers.
    if compute_kl and sae:
        metrics["kl_div_score"] = (
            metrics["kl_div_with_ablation"] - metrics["kl_div_with_sae"]
        ) / metrics["kl_div_with_ablation"]

    if compute_ce_loss:
        metrics["ce_loss_score"] = (
            metrics["ce_loss_with_ablation"] - metrics["ce_loss_with_sae"]
        ) / (metrics["ce_loss_with_ablation"] - metrics["ce_loss_without_sae"])

    return metrics


def get_sparsity_and_variance_metrics(
    sae: SAE,
    model: HookedRootModule,
    activation_store: ActivationsStore,
    n_batches: int,
    compute_l2_norms: bool,
    compute_sparsity_metrics: bool,
    compute_variance_metrics: bool,
    compute_featurewise_density_statistics: bool,
    eval_batch_size_prompts: int,
    model_kwargs: Mapping[str, Any],
    ignore_tokens: set[int | None] = set(),
    verbose: bool = False,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Compute shrinkage, sparsity and reconstruction-quality metrics.

    For ``n_batches`` token batches the model is run with caching, the
    activation at the SAE's hook point is encoded and decoded by the SAE, and
    the requested statistics are accumulated over all positions whose token is
    not in ``ignore_tokens``.

    For SAEs whose architecture name contains ``mlp_neuron_sae``,
    ``sparse_mlp_neuron_sae`` or ``transcoder``, or whose hook name contains
    ``ln2.hook_normalized``, the variance metrics compare the SAE output with
    the activation at ``blocks.{hook_layer}.hook_mlp_out`` instead of with the
    SAE input.

    Args:
        sae: The SAE under evaluation.
        model: The hooked model the SAE is attached to.
        activation_store: Source of token batches and of the optional
            activation norm scaling.
        n_batches: Number of token batches to evaluate.
        compute_l2_norms: Collect ``l2_norm_in``, ``l2_norm_out``,
            ``l2_ratio`` and ``relative_reconstruction_bias``.
        compute_sparsity_metrics: Collect ``l0`` and ``l1``.
        compute_variance_metrics: Collect ``explained_variance``,
            ``explained_variance_legacy``, ``mse`` and ``cossim``.
        compute_featurewise_density_statistics: Accumulate the per-feature
            firing counts behind the returned feature metrics.
        eval_batch_size_prompts: Number of prompts requested per batch.
        model_kwargs: Extra keyword arguments for ``model.run_with_cache``.
        ignore_tokens: Token ids whose positions are excluded. The shared
            default set is never mutated here.
        verbose: Show a progress bar over the batches.

    Returns:
        ``(metrics, feature_metrics)``: a dict of scalar floats and a dict
        with the per-feature lists ``feature_density`` and
        ``consistent_activation_heuristic``.

    Raises:
        ValueError: If the SAE's hook point is missing from the model cache.
    """
    hook_name = sae.cfg.hook_name
    hook_head_index = sae.cfg.hook_head_index

    is_special_sae = False
    if hasattr(sae.cfg, "architecture") and isinstance(sae.cfg.architecture, str):
        is_special_sae = any(
            t in sae.cfg.architecture.lower()
            for t in ["mlp_neuron_sae", "sparse_mlp_neuron_sae", "transcoder"]
        )
    if not is_special_sae and hasattr(sae.cfg, "hook_name"):
        is_special_sae = "ln2.hook_normalized" in str(sae.cfg.hook_name).lower()

    mlp_out_hook = (
        f"blocks.{sae.cfg.hook_layer}.hook_mlp_out" if is_special_sae else None
    )

    # Loop invariants. With no filter (None) every hook is cached.
    cache_hooks = [hook_name, mlp_out_hook] if is_special_sae and mlp_out_hook else None

    # we would include hook z, except that we now have base SAE's
    # which will do their own reshaping for hook z.
    has_head_dim_key_substrings = ["hook_q", "hook_k", "hook_v", "hook_z"]
    hook_has_head_dim = any(
        substring in hook_name for substring in has_head_dim_key_substrings
    )

    metric_dict = {}

    if compute_l2_norms:
        metric_dict["l2_norm_in"] = []
        metric_dict["l2_norm_out"] = []
        metric_dict["l2_ratio"] = []
        metric_dict["relative_reconstruction_bias"] = []
    if compute_sparsity_metrics:
        metric_dict["l0"] = []
        metric_dict["l1"] = []
    if compute_variance_metrics:
        # The "explained_variance" list stays empty; the value is derived
        # from the three accumulators below during aggregation.
        metric_dict["explained_variance"] = []
        metric_dict["explained_variance_legacy"] = []
        mean_sum_of_squares = []  # for explained variance
        mean_act_per_dimension = []  # for explained variance
        mean_sum_of_resid_squared = []  # for explained variance
        metric_dict["mse"] = []
        metric_dict["cossim"] = []

    total_feature_acts = torch.zeros(sae.cfg.d_sae, device=sae.device)
    total_feature_prompts = torch.zeros(sae.cfg.d_sae, device=sae.device)
    total_tokens = 0

    batch_iter = range(n_batches)
    if verbose:
        batch_iter = tqdm(batch_iter, desc="Sparsity and Variance Batches")

    for _ in batch_iter:
        batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)

        if len(ignore_tokens) > 0:
            mask = torch.logical_not(
                torch.any(
                    torch.stack(
                        [batch_tokens == token for token in ignore_tokens], dim=0
                    ),
                    dim=0,
                )
            )
        else:
            mask = torch.ones_like(batch_tokens, dtype=torch.bool)
        flattened_mask = mask.flatten()

        # get cache
        _, cache = model.run_with_cache(
            batch_tokens,
            prepend_bos=False,
            names_filter=cache_hooks,
            stop_at_layer=sae.cfg.hook_layer + 1,
            **model_kwargs,
        )

        # Continue with the original hook we need
        if hook_name not in cache:
            raise ValueError(f"Hook {hook_name} not found in available hooks")

        if hook_head_index is not None:
            original_act = cache[hook_name][:, :, hook_head_index]
        elif hook_has_head_dim:
            original_act = cache[hook_name].flatten(-2, -1)
        else:
            original_act = cache[hook_name]

        # MODIFIED: For FFKV/Transcoder SAEs, get the FFKV output for comparison instead of input
        if is_special_sae and mlp_out_hook and mlp_out_hook in cache:
            comparison_act = cache[mlp_out_hook]
            if hook_head_index is not None:
                comparison_act = comparison_act[:, :, hook_head_index]
            elif hook_has_head_dim:
                comparison_act = comparison_act.flatten(-2, -1)
        else:
            comparison_act = original_act

        # Cast to the SAE dtype. The identity check has to happen before
        # `original_act` is rebound; when both names refer to the same tensor
        # a single cast is shared instead of converting twice.
        shares_input = comparison_act is original_act
        if original_act.dtype != sae.dtype:
            original_act = original_act.to(sae.dtype)
        if shares_input:
            comparison_act = original_act
        elif comparison_act.dtype != sae.dtype:
            comparison_act = comparison_act.to(sae.dtype)

        # normalise if necessary (necessary in training only, otherwise we should fold the scaling in)
        if activation_store.normalize_activations == "expected_average_only_in":
            original_act = activation_store.apply_norm_scaling_factor(original_act)

        # send the (maybe normalised) activations into the SAE
        sae_feature_activations = sae.encode(original_act.to(sae.device))
        sae_out = sae.decode(sae_feature_activations).to(original_act.device)
        del cache

        if activation_store.normalize_activations == "expected_average_only_in":
            sae_out = activation_store.unscale(sae_out)

        # Flatten for easier processing
        flattened_sae_input = einops.rearrange(original_act, "b ctx d -> (b ctx) d")
        flattened_comparison_act = einops.rearrange(
            comparison_act, "b ctx d -> (b ctx) d"
        )
        flattened_sae_feature_acts = einops.rearrange(
            sae_feature_activations, "b ctx d -> (b ctx) d"
        )
        flattened_sae_out = einops.rearrange(sae_out, "b ctx d -> (b ctx) d")

        # TODO: Clean this up.
        # apply mask
        masked_sae_feature_activations = sae_feature_activations * mask.unsqueeze(-1)
        flattened_sae_input = flattened_sae_input[flattened_mask]
        flattened_comparison_act = flattened_comparison_act[flattened_mask]
        flattened_sae_feature_acts = flattened_sae_feature_acts[flattened_mask]
        flattened_sae_out = flattened_sae_out[flattened_mask]

        if compute_l2_norms:
            # For L2 norms, we use the input and output directly
            l2_norm_in = torch.norm(flattened_sae_input, dim=-1)
            l2_norm_out = torch.norm(flattened_sae_out, dim=-1)
            # Avoid dividing by (near-)zero input norms.
            l2_norm_in_for_div = l2_norm_in.clone()
            l2_norm_in_for_div[torch.abs(l2_norm_in_for_div) < 0.0001] = 1
            l2_norm_ratio = l2_norm_out / l2_norm_in_for_div

            # Equation 10 from https://arxiv.org/abs/2404.16014
            # https://github.com/saprmarks/dictionary_learning/blob/main/evaluation.py
            x_hat_norm_squared = torch.norm(flattened_sae_out, dim=-1) ** 2
            x_dot_x_hat = (flattened_sae_input * flattened_sae_out).sum(dim=-1)
            relative_reconstruction_bias = (
                x_hat_norm_squared.mean() / x_dot_x_hat.mean()
            ).unsqueeze(0)

            metric_dict["l2_norm_in"].append(l2_norm_in)
            metric_dict["l2_norm_out"].append(l2_norm_out)
            metric_dict["l2_ratio"].append(l2_norm_ratio)
            metric_dict["relative_reconstruction_bias"].append(
                relative_reconstruction_bias
            )

        if compute_sparsity_metrics:
            l0 = (flattened_sae_feature_acts != 0).sum(dim=-1).float()
            l1 = flattened_sae_feature_acts.sum(dim=-1)
            metric_dict["l0"].append(l0)
            metric_dict["l1"].append(l1)

        if compute_variance_metrics:
            # For variance metrics, compare reconstruction against the target activations
            # (either original or mlp_out for special SAEs)
            resid_sum_of_squares = (
                (flattened_comparison_act - flattened_sae_out).pow(2).sum(dim=-1)
            )
            mse = resid_sum_of_squares / flattened_mask.sum()

            # Explained variance (old, incorrect, formula)
            batched_variance_sum = (
                (flattened_comparison_act - flattened_comparison_act.mean(dim=0))
                .pow(2)
                .sum(dim=-1)
            )
            explained_variance_legacy = 1 - resid_sum_of_squares / batched_variance_sum
            metric_dict["explained_variance_legacy"].append(explained_variance_legacy)

            # Individual sums for the new (correct) formula. We're taking the mean over the batch
            # dimension here to save memory, but we could also pass the full tensors and take the
            # mean later (like we do for other metrics).
            mean_sum_of_squares.append(  # type: ignore
                (flattened_comparison_act).pow(2).sum(dim=-1).mean(dim=0)  # scalar
            )
            mean_act_per_dimension.append(  # type: ignore
                (flattened_comparison_act).pow(2).mean(dim=0)  # [d_model]
            )
            mean_sum_of_resid_squared.append(  # type: ignore
                resid_sum_of_squares.mean(dim=0)  # scalar
            )

            x_normed = flattened_comparison_act / torch.norm(
                flattened_comparison_act, dim=-1, keepdim=True
            )
            x_hat_normed = flattened_sae_out / torch.norm(
                flattened_sae_out, dim=-1, keepdim=True
            )
            cossim = (x_normed * x_hat_normed).sum(dim=-1)

            metric_dict["mse"].append(mse)
            metric_dict["cossim"].append(cossim)

        if compute_featurewise_density_statistics:
            sae_feature_activations_bool = (masked_sae_feature_activations > 0).float()
            total_feature_acts += sae_feature_activations_bool.sum(dim=1).sum(dim=0)
            total_feature_prompts += (sae_feature_activations_bool.sum(dim=1) > 0).sum(
                dim=0
            )
            total_tokens += mask.sum()

    # Aggregate scalar metrics
    metrics: dict[str, float] = {}
    for metric_name, metric_values in metric_dict.items():
        if metric_name != "explained_variance":
            metrics[metric_name] = torch.cat(metric_values).mean().item()
        else:
            # NOTE(review): the per-dimension accumulator holds means of the
            # *squared* activations and is reduced to one scalar across all
            # dimensions before being squared. Kept unchanged so reported
            # numbers stay comparable — confirm this is the intended formula.
            mean_sum_of_squares = torch.stack(mean_sum_of_squares).mean(dim=0)  # type: ignore
            mean_act_per_dimension = torch.cat(mean_act_per_dimension).mean(dim=0)  # type: ignore
            total_variance = mean_sum_of_squares - mean_act_per_dimension**2
            residual_variance = torch.stack(mean_sum_of_resid_squared).mean(dim=0)  # type: ignore
            # Convert to a Python float like every other entry of `metrics`.
            metrics["explained_variance"] = (
                1 - residual_variance / total_variance
            ).item()

    # Aggregate feature-wise metrics
    # NOTE: the counters are only updated when
    # compute_featurewise_density_statistics is set; otherwise these are 0/0
    # divisions and every entry is NaN. Features that never fire likewise
    # yield NaN for the consistent-activation heuristic.
    feature_metrics: dict[str, list[float]] = {}
    feature_metrics["feature_density"] = (total_feature_acts / total_tokens).tolist()
    feature_metrics["consistent_activation_heuristic"] = (
        total_feature_acts / total_feature_prompts
    ).tolist()

    return metrics, feature_metrics


@torch.no_grad()
def get_recons_loss(
    sae: SAE,
    model: HookedRootModule,
    batch_tokens: torch.Tensor,
    activation_store: ActivationsStore,
    compute_kl: bool,
    compute_ce_loss: bool,
    ignore_tokens: set[int | None] = set(),
    exclude_special_tokens_from_reconstruction: bool = False,
    model_kwargs: Mapping[str, Any] = {},
) -> dict[str, Any]:
    """Measure how splicing the SAE into the model changes its outputs.

    Three forward passes are run on ``batch_tokens``: the unmodified model,
    the model with the activations at ``sae.cfg.hook_name`` replaced by the
    SAE reconstruction, and the model with those activations zeroed.

    Args:
        sae: SAE whose ``cfg.hook_name`` / ``cfg.hook_head_index`` select the
            hook point and whose ``encode``/``decode`` produce the
            reconstruction.
        model: Hooked model; called directly and through ``run_with_hooks``.
        batch_tokens: Token ids fed to the model.
        activation_store: Only ``normalize_activations``,
            ``apply_norm_scaling_factor`` and ``unscale`` are used, to scale
            activations before the SAE and undo the scaling afterwards.
        compute_kl: Include the KL-divergence entries in the result.
        compute_ce_loss: Include the cross-entropy entries in the result.
        ignore_tokens: Token ids whose positions keep their original
            activations. ``None`` entries are skipped. Only used when
            ``exclude_special_tokens_from_reconstruction`` is true. The
            default is never mutated.
        exclude_special_tokens_from_reconstruction: Enable the masking
            described for ``ignore_tokens``.
        model_kwargs: Extra keyword arguments forwarded to every model call.
            The default is never mutated.

    Returns:
        Dict with ``kl_div_with_sae`` / ``kl_div_with_ablation`` (when
        ``compute_kl``) and ``ce_loss_with_sae`` / ``ce_loss_without_sae`` /
        ``ce_loss_with_ablation`` (when ``compute_ce_loss``). Values are the
        unreduced tensors produced by the passes (losses are requested with
        ``loss_per_token=True``); no masking is applied to them here.
    """
    # MODIFIED
    hook_name = sae.cfg.hook_name
    head_index = sae.cfg.hook_head_index

    original_logits, original_ce_loss = model(
        batch_tokens, return_type="both", loss_per_token=True, **model_kwargs
    )

    # A missing special token shows up as None. Comparing a tensor with None
    # does not give a tensor, so such entries would break torch.stack below.
    token_ids_to_ignore = [token for token in ignore_tokens if token is not None]

    if token_ids_to_ignore and exclude_special_tokens_from_reconstruction:
        mask = torch.logical_not(
            torch.any(
                torch.stack(
                    [batch_tokens == token for token in token_ids_to_ignore], dim=0
                ),
                dim=0,
            )
        )
    else:
        mask = torch.ones_like(batch_tokens, dtype=torch.bool)

    metrics = {}

    def scale_for_sae(activations: torch.Tensor) -> torch.Tensor:
        # Handle rescaling if SAE expects it
        if activation_store.normalize_activations == "expected_average_only_in":
            return activation_store.apply_norm_scaling_factor(activations)
        return activations

    def unscale_from_sae(activations: torch.Tensor) -> torch.Tensor:
        # Unscale if activations were scaled prior to going into the SAE
        if activation_store.normalize_activations == "expected_average_only_in":
            return activation_store.unscale(activations)
        return activations

    # TODO(tomMcGrath): the rescaling below is a bit of a hack and could probably be tidied up
    def standard_replacement_hook(activations: torch.Tensor, hook: Any):
        original_device = activations.device
        unscaled_activations = activations.to(sae.device)
        sae_input = scale_for_sae(unscaled_activations)

        # SAE class agnostic forward pass.
        reconstructed_activations = sae.decode(sae.encode(sae_input)).to(
            sae_input.dtype
        )
        reconstructed_activations = unscale_from_sae(reconstructed_activations)

        # Ignored positions must get the *unscaled* originals back; the
        # reconstruction has already been unscaled at this point.
        reconstructed_activations = torch.where(
            mask.to(reconstructed_activations.device)[..., None],
            reconstructed_activations,
            unscaled_activations,
        )

        return reconstructed_activations.to(original_device)

    def all_head_replacement_hook(activations: torch.Tensor, hook: Any):
        original_device = activations.device
        unscaled_activations = activations.to(sae.device)
        sae_input = scale_for_sae(unscaled_activations)

        # SAE class agnostic forward pass over the flattened last two dims.
        new_activations = sae.decode(sae.encode(sae_input.flatten(-2, -1))).to(
            sae_input.dtype
        )

        # reshape to match original shape
        new_activations = new_activations.reshape(sae_input.shape)
        new_activations = unscale_from_sae(new_activations)

        # Apply mask to keep the original (unscaled) activations for ignored tokens
        new_activations = torch.where(
            mask.to(new_activations.device)[..., None, None],
            new_activations,
            unscaled_activations,
        )

        return new_activations.to(original_device)

    def single_head_replacement_hook(activations: torch.Tensor, hook: Any):
        original_device = activations.device
        activations = scale_for_sae(activations.to(sae.device))

        # Create a copy of activations to modify
        new_activations = activations.clone()

        # Only reconstruct the specified head
        head_activations = sae.decode(sae.encode(activations[:, :, head_index])).to(
            activations.dtype
        )

        # Apply mask only to the reconstructed head. Everything is still in
        # the scaled space here, so the single unscale below restores both
        # the untouched heads and the ignored positions.
        masked_head_activations = torch.where(
            mask.to(head_activations.device)[..., None],
            head_activations,
            activations[:, :, head_index],
        )
        new_activations[:, :, head_index] = masked_head_activations

        new_activations = unscale_from_sae(new_activations)

        return new_activations.to(original_device)

    def standard_zero_ablate_hook(activations: torch.Tensor, hook: Any):
        original_device = activations.device
        activations = activations.to(sae.device)
        activations = torch.zeros_like(activations)
        return activations.to(original_device)

    def single_head_zero_ablate_hook(activations: torch.Tensor, hook: Any):
        original_device = activations.device
        activations = activations.to(sae.device)
        activations[:, :, head_index] = torch.zeros_like(activations[:, :, head_index])
        return activations.to(original_device)

    # Hook points whose name contains one of these substrings are handled
    # with the per-head hooks.
    # NOTE(review): an earlier comment here said hook_z would be left out
    # because base SAEs reshape hook_z themselves, yet "hook_z" is in the
    # list - confirm which is intended.
    has_head_dim_key_substrings = ["hook_q", "hook_k", "hook_v", "hook_z"]
    if any(substring in hook_name for substring in has_head_dim_key_substrings):
        if head_index is None:
            replacement_hook = all_head_replacement_hook
            zero_ablate_hook = standard_zero_ablate_hook
        else:
            replacement_hook = single_head_replacement_hook
            zero_ablate_hook = single_head_zero_ablate_hook
    else:
        replacement_hook = standard_replacement_hook
        zero_ablate_hook = standard_zero_ablate_hook

    recons_logits, recons_ce_loss = model.run_with_hooks(
        batch_tokens,
        return_type="both",
        fwd_hooks=[(hook_name, replacement_hook)],
        loss_per_token=True,
        **model_kwargs,
    )

    zero_abl_logits, zero_abl_ce_loss = model.run_with_hooks(
        batch_tokens,
        return_type="both",
        fwd_hooks=[(hook_name, zero_ablate_hook)],
        loss_per_token=True,
        **model_kwargs,
    )

    def kl(original_logits: torch.Tensor, new_logits: torch.Tensor):
        # log_softmax avoids log(0) = -inf (and the resulting 0 * -inf = NaN)
        # when a probability underflows, which log(softmax(x)) is prone to.
        log_original_probs = torch.nn.functional.log_softmax(original_logits, dim=-1)
        log_new_probs = torch.nn.functional.log_softmax(new_logits, dim=-1)
        original_probs = log_original_probs.exp()
        kl_div = original_probs * (log_original_probs - log_new_probs)
        return kl_div.sum(dim=-1)

    if compute_kl:
        metrics["kl_div_with_sae"] = kl(original_logits, recons_logits)
        metrics["kl_div_with_ablation"] = kl(original_logits, zero_abl_logits)

    if compute_ce_loss:
        metrics["ce_loss_with_sae"] = recons_ce_loss
        metrics["ce_loss_without_sae"] = original_ce_loss
        metrics["ce_loss_with_ablation"] = zero_abl_ce_loss

    return metrics


def all_loadable_saes() -> list[tuple[str, str, float, float]]:
    """List every SAE in the pretrained-SAEs directory.

    Returns:
        One ``(release, sae_name, expected_var_explained, expected_l0)``
        tuple per entry of each release's ``saes_map``, in directory order.
    """
    directory = get_pretrained_saes_directory()
    return [
        (
            release_name,
            sae_key,
            release_entry.expected_var_explained[sae_key],
            release_entry.expected_l0[sae_key],
        )
        for release_name, release_entry in tqdm(directory.items())
        for sae_key in release_entry.saes_map.keys()
    ]


def nested_dict() -> defaultdict[Any, Any]:
    """Return an empty dict whose missing keys create further nested dicts.

    The default factory is this function itself, so arbitrarily deep paths
    such as ``d["a"]["b"]["c"] = 1`` work without creating the levels first.
    """
    auto_tree: defaultdict[Any, Any] = defaultdict(nested_dict)
    return auto_tree


def dict_to_nested(flat_dict: dict[str, Any]) -> defaultdict[Any, Any]:
    """Expand ``"a/b/c"``-style keys into a nested dictionary.

    Each key is split on ``"/"``; every part but the last names a nesting
    level and the last part holds the value.

    Args:
        flat_dict: Mapping from slash-separated paths to values.

    Returns:
        A ``nested_dict()`` containing the same values at the split paths.
    """
    tree = nested_dict()
    for flat_key, leaf_value in flat_dict.items():
        *branch_names, leaf_name = flat_key.split("/")
        node = tree
        for branch_name in branch_names:
            node = node[branch_name]
        node[leaf_name] = leaf_value
    return tree


def convert_feature_metrics(
    flattened_feature_metrics: dict[str, list[float]],
) -> list[CoreFeatureMetric]:
    """Turn per-metric value lists into one record per feature.

    Args:
        flattened_feature_metrics: Maps each metric name to a list holding
            one value per feature. May be empty.

    Returns:
        One CoreFeatureMetric per feature, with every value rounded to
        DEFAULT_FLOAT_PRECISION decimals. Empty when the input is empty.
    """
    if not flattened_feature_metrics:
        return []

    # Looked up in this order for every feature.
    metric_names = (
        "consistent_activation_heuristic",
        "encoder_bias",
        "encoder_decoder_cosine_sim",
        "encoder_norm",
        "feature_density",
        "max_decoder_cosine_sim",
        "max_encoder_cosine_sim",
        "mean_decoder_cosine_sim",
        "mean_encoder_cosine_sim",
    )

    # The feature count is taken from this metric's list.
    feature_count = len(flattened_feature_metrics["consistent_activation_heuristic"])

    return [
        CoreFeatureMetric(
            index=feature_idx,
            **{
                metric_name: round(
                    flattened_feature_metrics[metric_name][feature_idx],
                    DEFAULT_FLOAT_PRECISION,
                )
                for metric_name in metric_names
            },
        )
        for feature_idx in range(feature_count)
    ]


def save_single_eval_result(
    result: dict[str, Any],
    eval_instance_id: str,
    sae_lens_version: str,
    sae_bench_commit_hash: str,
    json_path: str,
    sae: SAE,
) -> str:
    """Assemble a CoreEvalOutput from one evaluation result and write it as JSON.

    Args:
        result: Evaluation result; reads ``"eval_cfg"``, ``"metrics"``,
            ``"sae_id"``, ``"sae_set"`` and, optionally, ``"feature_metrics"``.
        eval_instance_id: Stored as the output's ``eval_id``.
        sae_lens_version: Stored in the output.
        sae_bench_commit_hash: Stored in the output.
        json_path: Destination passed to ``to_json_file``.
        sae: Its ``cfg`` is stored via ``asdict``.

    Returns:
        ``json_path``, unchanged.
    """
    # Already a CoreEvalConfig object, used as-is.
    eval_config = result["eval_cfg"]
    metrics_by_category = result["metrics"]

    # KL / CE entries may be absent (e.g. when those passes were disabled);
    # the metric classes require them, so absent ones get a -999 placeholder.
    behavior_metrics = metrics_by_category.get("model_behavior_preservation", {})
    for required_key in ("kl_div_score", "kl_div_with_ablation", "kl_div_with_sae"):
        behavior_metrics.setdefault(required_key, -999)

    performance_metrics = metrics_by_category.get(
        "model_performance_preservation", {}
    )
    for required_key in (
        "ce_loss_score",
        "ce_loss_with_ablation",
        "ce_loss_with_sae",
        "ce_loss_without_sae",
    ):
        performance_metrics.setdefault(required_key, -999)

    metric_categories = CoreMetricCategories(
        model_behavior_preservation=ModelBehaviorPreservationMetrics(
            **behavior_metrics
        ),
        model_performance_preservation=ModelPerformancePreservationMetrics(
            **performance_metrics
        ),
        reconstruction_quality=ReconstructionQualityMetrics(
            **metrics_by_category.get("reconstruction_quality", {})
        ),
        shrinkage=ShrinkageMetrics(**metrics_by_category.get("shrinkage", {})),
        sparsity=SparsityMetrics(**metrics_by_category.get("sparsity", {})),
        token_stats=TokenStatsMetrics(**metrics_by_category.get("token_stats", {})),
        misc_metrics=MiscMetrics(**metrics_by_category.get("misc_metrics", {})),
    )

    # Per-feature metrics arrive as parallel lists; store one record per feature.
    per_feature_records = convert_feature_metrics(result.get("feature_metrics", {}))

    eval_output = CoreEvalOutput(
        eval_config=eval_config,
        eval_id=eval_instance_id,
        datetime_epoch_millis=int(time.time() * 1000),
        eval_result_metrics=metric_categories,
        eval_result_details=per_feature_records,
        eval_result_unstructured={},  # no unstructured results for this eval
        sae_bench_commit_hash=sae_bench_commit_hash,
        sae_lens_id=result["sae_id"],
        sae_lens_release_id=result["sae_set"],
        sae_lens_version=sae_lens_version,
        sae_cfg_dict=asdict(sae.cfg),
    )
    eval_output.to_json_file(json_path)

    return json_path


def calculate_misc_metrics(feature_metrics: dict[str, torch.Tensor]) -> dict:
    """Summarise per-feature metrics into scalar statistics.

    Args:
        feature_metrics: Per-feature values keyed by metric name. Reads the
            four ``{max,mean}_{encoder,decoder}_cosine_sim`` entries and
            ``feature_density``.

    Returns:
        Dict of Python floats: the mean of each cosine-similarity metric,
        the fraction of features with density > 0, > 0.01 and > 0.1, and the
        share of total density carried by features above 0.01 and above 0.1
        (0.0 when the total density is not positive).
    """

    def average_of(metric_key: str) -> float:
        return torch.Tensor(feature_metrics[metric_key]).mean().item()

    mean_of_max_encoder = average_of("max_encoder_cosine_sim")
    mean_of_max_decoder = average_of("max_decoder_cosine_sim")
    mean_of_mean_encoder = average_of("mean_encoder_cosine_sim")
    mean_of_mean_decoder = average_of("mean_decoder_cosine_sim")

    densities = torch.Tensor(feature_metrics["feature_density"]).float()
    density_total = densities.sum()

    def fraction_above(threshold: float) -> float:
        # Fraction of features whose density exceeds the threshold.
        return (densities > threshold).float().mean().item()

    def density_share_above(threshold: float) -> float:
        # Density summed over features above the threshold, relative to the
        # density summed over all features.
        if density_total > 0:
            return (densities[densities > threshold].sum() / density_total).item()
        return 0.0

    return {
        "average_max_encoder_cosine_sim": mean_of_max_encoder,
        "average_max_decoder_cosine_sim": mean_of_max_decoder,
        "average_mean_encoder_cosine_sim": mean_of_mean_encoder,
        "average_mean_decoder_cosine_sim": mean_of_mean_decoder,
        "frac_alive": fraction_above(0),
        "freq_over_1_percent": fraction_above(0.01),
        "freq_over_10_percent": fraction_above(0.1),
        "normalized_freq_over_1_percent": density_share_above(0.01),
        "normalized_freq_over_10_percent": density_share_above(0.1),
    }


def multiple_evals(
    selected_saes: list[tuple[str, str]] | list[tuple[str, SAE]],
    n_eval_reconstruction_batches: int,
    n_eval_sparsity_variance_batches: int,
    is_random_transformer: bool,
    model_instance: HookedTransformer | None,
    eval_batch_size_prompts: int = 8,
    compute_featurewise_density_statistics: bool = False,
    compute_featurewise_weight_based_metrics: bool = False,
    exclude_special_tokens_from_reconstruction: bool = False,
    dataset: str = "Skylion007/openwebtext",
    context_size: int = 128,
    output_folder: str = "eval_results",
    verbose: bool = False,
    dtype: str = "float32",
    device: str = "cuda",
    force_rerun: bool = False,
) -> list[dict[str, Any]]:
    """Run the core evaluation on each selected SAE and save one JSON per SAE.

    For every ``(release, sae_or_id)`` pair the SAE is loaded and cast to
    ``device`` / ``dtype``, the matching model is loaded (or reused), an
    ActivationsStore is built, ``run_evals`` is called and the result is
    written with ``save_single_eval_result``. A failure for one SAE is
    logged or printed and the loop moves on to the next SAE.

    Args:
        selected_saes: Non-empty list of ``(release, sae id or SAE object)``.
        n_eval_reconstruction_batches: Forwarded to the eval config.
        n_eval_sparsity_variance_batches: Forwarded to the eval config.
        is_random_transformer: Forwarded to
            ``general_utils.get_results_filepath`` when building the result
            path.
        model_instance: If given, used instead of loading a model from
            pretrained weights. Ignored for releases whose name contains
            ``"gemma-2-2b_gemma_scope_transcoder"``, which always load with
            ``fold_ln=True``.
        eval_batch_size_prompts: Prompt batch size for the eval config.
        compute_featurewise_density_statistics: Forwarded to the eval
            config. If this or the weight-based flag is set, misc metrics
            and per-feature metrics are added to the result.
        compute_featurewise_weight_based_metrics: See above.
        exclude_special_tokens_from_reconstruction: Forwarded to the eval
            config.
        dataset: Dataset name for the ActivationsStore and eval config.
        context_size: Context length for the ActivationsStore and eval
            config.
        output_folder: Directory for result files; created if missing.
        verbose: Forwarded to ``run_evals``; also prints each saved path.
        dtype: Dtype name; the SAE is cast to it and it is recorded as
            ``llm_dtype``.
        device: Device for the SAE and the model.
        force_rerun: Re-evaluate even when a result file already exists.

    Returns:
        The metrics of every SAE that was evaluated and saved in this call.
        These are the un-cleaned metrics (NaNs are replaced only in the
        saved copy). SAEs skipped because results exist are not included.
    """
    assert len(selected_saes) > 0, "No SAEs to evaluate"

    eval_results = []
    output_path = Path(output_folder)
    output_path.mkdir(parents=True, exist_ok=True)

    # Get evaluation metadata once at the start
    eval_instance_id = get_eval_uuid()
    sae_lens_version = get_sae_lens_version()
    sae_bench_commit_hash = get_sae_bench_version()

    multiple_evals_config = get_multiple_evals_everything_config(
        batch_size_prompts=eval_batch_size_prompts,
        n_eval_reconstruction_batches=n_eval_reconstruction_batches,
        n_eval_sparsity_variance_batches=n_eval_sparsity_variance_batches,
    )

    # The loaded model is identified by (model name, fold_ln). Keying on the
    # name alone would let an SAE that needs the unfolded model reuse a model
    # that was loaded with fold_ln=True for a transcoder, and vice versa.
    current_model = None
    current_model_key: tuple[str, bool] | None = None

    llm_dtype = general_utils.str_to_dtype(dtype)

    for sae_release, sae_object_or_id in tqdm(
        selected_saes, desc="Running SAE evaluation on all selected SAEs"
    ):
        sae_id, sae, _sparsity = general_utils.load_and_format_sae(
            sae_release, sae_object_or_id, device
        )  # type: ignore
        sae = sae.to(device=device, dtype=llm_dtype)
        print(f"\n====== Current evaluated SAE: {sae_release} ======\n")

        sae_result_path = general_utils.get_results_filepath(
            output_folder, sae_release, sae_id, is_random_transformer
        )

        if os.path.exists(sae_result_path) and not force_rerun:
            print(f"Skipping {sae_release}_{sae_id} as results already exist")
            continue

        use_folded_ln = "gemma-2-2b_gemma_scope_transcoder" in sae_release
        wanted_model_key = (sae.cfg.model_name, use_folded_ln)

        if current_model is None or current_model_key != wanted_model_key:
            # Wrap model loading with retry
            @general_utils.retry_with_exponential_backoff(
                retries=5,
                exceptions=(
                    Exception,
                ),  # We might want to be more specific about which exceptions to catch
                initial_delay=1.0,
                max_delay=60.0,
            )
            def load_model():
                if use_folded_ln:
                    print()
                    print("=" * 100)
                    print(f"Loading {sae.cfg.model_name} with fold_ln=True")
                    print("=" * 100)
                    print()

                    return HookedTransformer.from_pretrained_no_processing(
                        sae.cfg.model_name,
                        device=device,
                        dtype=sae.W_enc.dtype,
                        fold_ln=True,
                        center_writing_weights=False,
                        center_unembed=False,
                        **sae.cfg.model_from_pretrained_kwargs,
                    )

                if sae.cfg.model_name == "Llama-3.1-8B":
                    model_name_to_load = "meta-llama/Llama-3.1-8B"
                    print(f"Loading {model_name_to_load} with fold_ln=False")
                else:
                    model_name_to_load = sae.cfg.model_name
                    print()
                    print("=" * 100)
                    print(f"Loading {sae.cfg.model_name} with fold_ln=False")
                    print("=" * 100)
                    print()

                if model_instance is not None:
                    return model_instance

                return HookedTransformer.from_pretrained_no_processing(
                    model_name_to_load,
                    device=device,
                    dtype=sae.W_enc.dtype,
                    fold_ln=False,
                    center_writing_weights=False,
                    center_unembed=False,
                    **sae.cfg.model_from_pretrained_kwargs,
                )

            # Drop the previous model before loading so both are not held at
            # once. The key is recorded only after a successful load, so a
            # failed load is retried for the next SAE instead of leaving a
            # stale key pointing at a model that does not exist.
            current_model = None
            current_model_key = None
            try:
                current_model = load_model()
                current_model_key = wanted_model_key
            except Exception as e:
                if use_folded_ln:
                    logger.error(
                        f"\n❗❗❗Failed to load model {sae.cfg.model_name}: {str(e)}❗❗❗\n"
                    )
                else:
                    logger.error(
                        f"Failed to load model {sae.cfg.model_name}: {str(e)}"
                    )
                continue  # Skip this SAE and continue with the next one

        assert current_model is not None  # type: ignore

        try:
            # KL / CE are skipped for releases whose name contains one of
            # these substrings.
            should_compute_kl_ce = not any(
                t in sae_release
                for t in ["mlp_neuron_sae", "sparse_mlp_neuron_sae", "transcoder"]
            )
            if should_compute_kl_ce:
                print(f"\n========\nComputing KL/CE for {sae_release}\n========\n")
            else:
                print(f"\n========\nSkipping KL/CE for {sae_release}\n========\n")

            print(f"Context size: {context_size}")

            core_eval_config = CoreEvalConfig(
                model_name=sae.cfg.model_name,
                batch_size_prompts=multiple_evals_config.batch_size_prompts or 16,
                n_eval_reconstruction_batches=multiple_evals_config.n_eval_reconstruction_batches,
                n_eval_sparsity_variance_batches=multiple_evals_config.n_eval_sparsity_variance_batches,
                exclude_special_tokens_from_reconstruction=exclude_special_tokens_from_reconstruction,
                dataset=dataset,
                context_size=context_size,
                compute_kl=should_compute_kl_ce and multiple_evals_config.compute_kl,
                compute_ce_loss=should_compute_kl_ce
                and multiple_evals_config.compute_ce_loss,
                compute_l2_norms=multiple_evals_config.compute_l2_norms,
                compute_sparsity_metrics=multiple_evals_config.compute_sparsity_metrics,
                compute_variance_metrics=multiple_evals_config.compute_variance_metrics,
                compute_featurewise_density_statistics=compute_featurewise_density_statistics,
                compute_featurewise_weight_based_metrics=compute_featurewise_weight_based_metrics,
                llm_dtype=dtype,
            )

            # Wrap activation store creation with retry
            @general_utils.retry_with_exponential_backoff(
                retries=3,
                exceptions=(Exception,),
                initial_delay=1.0,
                max_delay=30.0,
            )
            def create_activation_store():
                return ActivationsStore.from_sae(
                    current_model,  # type: ignore
                    sae,
                    context_size=context_size,
                    dataset=dataset,
                )

            activation_store = create_activation_store()
            activation_store.shuffle_input_dataset(seed=42)

            eval_metrics = nested_dict()
            eval_metrics["unique_id"] = f"{sae_release}_{sae_id}"
            eval_metrics["sae_set"] = f"{sae_release}"
            eval_metrics["sae_id"] = f"{sae_id}"
            eval_metrics["eval_cfg"] = core_eval_config

            logger.debug(
                f"Starting evaluation for SAE {sae_id} with config: {core_eval_config}"
            )

            try:
                scalar_metrics, feature_metrics = run_evals(
                    sae=sae,
                    activation_store=activation_store,
                    model=current_model,  # type: ignore
                    eval_config=core_eval_config,
                    ignore_tokens={
                        current_model.tokenizer.pad_token_id,  # type: ignore
                        current_model.tokenizer.eos_token_id,  # type: ignore
                        current_model.tokenizer.bos_token_id,  # type: ignore
                    },
                    verbose=verbose,
                )
                eval_metrics["metrics"] = scalar_metrics

                if (
                    compute_featurewise_density_statistics
                    or compute_featurewise_weight_based_metrics
                ):
                    misc_metrics = calculate_misc_metrics(feature_metrics)
                    eval_metrics["metrics"]["misc_metrics"] = misc_metrics
                    eval_metrics["feature_metrics"] = feature_metrics

                # Clean NaN values before saving
                cleaned_metrics = replace_nans_with_negative_one(eval_metrics)

                # Save results immediately after each evaluation
                saved_path = save_single_eval_result(
                    cleaned_metrics,
                    eval_instance_id,
                    sae_lens_version,
                    sae_bench_commit_hash,
                    sae_result_path,
                    sae,
                )

                if verbose:
                    print(f"Saved evaluation results to: {saved_path}")

                eval_results.append(eval_metrics)
            except Exception as e:
                print(f"Debug: Error occurred: {str(e)}")
                print(f"Debug: Error type: {type(e).__name__}")
                import traceback

                print(f"Debug: Error traceback: {traceback.format_exc()}")
                print(
                    f"Failed to evaluate SAE {sae_id} from {sae_release} "
                    f"with context length {context_size} on dataset {dataset}: {str(e)}"
                )
        except Exception as e:
            # Setup for this SAE failed; log it and move on to the next one.
            logger.error(
                f"Failed to evaluate SAE {sae_id} from {sae_release} "
                f"with context length {context_size} on dataset {dataset}: {str(e)}"
            )
        finally:
            # Release memory after every attempt, including failed ones.
            gc.collect()
            torch.cuda.empty_cache()
    return eval_results


# def run_evaluations(args: argparse.Namespace) -> list[dict[str, Any]]:
#     device = general_utils.setup_environment()
#     # Filter SAEs based on regex patterns
#     filtered_saes = sae_selection_utils.get_saes_from_regex(
#         args.sae_regex_pattern, args.sae_block_pattern
#     )

#     # print the filtered SAEs
#     print("Filtered SAEs based on provided patterns:")
#     for sae in filtered_saes:
#         print(sae)

#     num_sae_sets = len(set(sae_set for sae_set, _ in filtered_saes))
#     num_all_sae_ids = len(filtered_saes)

#     print("Filtered SAEs based on provided patterns:")
#     print(f"Number of SAE sets: {num_sae_sets}")
#     print(f"Total number of SAE IDs: {num_all_sae_ids}")

#     eval_results = multiple_evals(
#         selected_saes=filtered_saes,
#         n_eval_reconstruction_batches=args.n_eval_reconstruction_batches,
#         n_eval_sparsity_variance_batches=args.n_eval_sparsity_variance_batches,
#         is_random_transformer=args.is_random_transformer,
#         eval_batch_size_prompts=args.batch_size_prompts,
#         compute_featurewise_density_statistics=True,  # TODO: Don't hardcode this
#         compute_featurewise_weight_based_metrics=True,
#         exclude_special_tokens_from_reconstruction=args.exclude_special_tokens_from_reconstruction,
#         dataset=args.dataset,
#         context_size=args.context_size,
#         output_folder=args.output_folder,
#         verbose=args.verbose,
#         dtype=args.llm_dtype,
#         device=device,
#         force_rerun=args.force_rerun,
#     )

#     return eval_results


def replace_nans_with_negative_one(obj: Any) -> Any:
    """Return a copy of ``obj`` in which every NaN float is replaced by ``-1``.

    Dicts and lists are rebuilt recursively (dict subclasses come back as
    plain ``dict``); any other value, including tuples and their contents,
    is returned unchanged.
    """
    if isinstance(obj, float):
        return -1 if math.isnan(obj) else obj
    if isinstance(obj, dict):
        return {
            key: replace_nans_with_negative_one(value) for key, value in obj.items()
        }
    if isinstance(obj, list):
        return [replace_nans_with_negative_one(element) for element in obj]
    return obj


def arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the core evaluation.

    The parser takes two positional arguments (``sae_regex_pattern`` and
    ``sae_block_pattern``) followed by optional flags that control batch
    sizes, which metric groups are computed, the dataset, the context size,
    the output folder and the LLM dtype.

    Returns:
        The fully configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description="Run core evaluation")
    parser.add_argument(
        "--model_name",
        type=str,
        default="",
        help="Model name. Currently this flag is ignored and the model name is inferred from sae.cfg.model_name.",
    )
    parser.add_argument(
        "sae_regex_pattern",
        type=str,
        help="Regex pattern to match SAE names. Can be an entire SAE name to match a specific SAE.",
    )
    parser.add_argument(
        "sae_block_pattern",
        type=str,
        help="Regex pattern to match SAE block names. Can be an entire block name to match a specific block.",
    )
    parser.add_argument(
        "--batch_size_prompts",
        type=int,
        default=16,
        help="Batch size for evaluation prompts.",
    )
    parser.add_argument(
        "--n_eval_reconstruction_batches",
        type=int,
        default=10,
        help="Number of evaluation batches for reconstruction metrics.",
    )
    parser.add_argument(
        "--compute_kl",
        action="store_true",
        help="Compute KL divergence.",
    )
    parser.add_argument(
        "--compute_ce_loss",
        action="store_true",
        help="Compute cross-entropy loss.",
    )
    parser.add_argument(
        "--n_eval_sparsity_variance_batches",
        type=int,
        default=1,
        help="Number of evaluation batches for sparsity and variance metrics.",
    )
    parser.add_argument(
        "--compute_l2_norms",
        action="store_true",
        help="Compute L2 norms.",
    )
    parser.add_argument(
        "--compute_sparsity_metrics",
        action="store_true",
        help="Compute sparsity metrics.",
    )
    parser.add_argument(
        "--compute_variance_metrics",
        action="store_true",
        help="Compute variance metrics.",
    )
    parser.add_argument(
        "--compute_featurewise_density_statistics",
        action="store_true",
        help="Compute featurewise density statistics.",
    )
    parser.add_argument(
        "--compute_featurewise_weight_based_metrics",
        action="store_true",
        help="Compute featurewise weight-based metrics.",
    )
    parser.add_argument(
        "--exclude_special_tokens_from_reconstruction",
        action="store_true",
        help="Exclude special tokens like BOS, EOS, PAD from reconstruction.",
    )
    parser.add_argument(
        "--dataset",
        default="Skylion007/openwebtext",
        help="Dataset to evaluate on, such as 'Skylion007/openwebtext' or 'lighteval/MATH'.",
    )
    parser.add_argument(
        "--context_size",
        type=int,
        default=128,
        help="Context size to evaluate on.",
    )
    parser.add_argument(
        "--output_folder",
        type=str,
        default="eval_results",
        help="Directory to save evaluation results",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose output with tqdm loaders.",
    )
    parser.add_argument(
        "--force_rerun", action="store_true", help="Force rerun of experiments"
    )
    parser.add_argument(
        "--llm_dtype",
        type=str,
        default="float32",
        choices=["float32", "float64", "float16", "bfloat16"],
        help="Data type for computation",
    )

    # Hand the configured parser back; without this the function returned None
    # and callers could not call .parse_args() on the result.
    return parser


if __name__ == "__main__":
    # Running this module directly is disabled: the entry point always raises
    # and points the user to run_all_evals.py instead.
    #
    # Example invocations of the original command-line interface, kept for
    # reference only:
    #
    #   python evals/core/main.py "sae_bench_pythia70m_sweep_standard_ctx128_0712" "blocks.4.hook_resid_post__trainer_10" \
    #   --batch_size_prompts 16 \
    #   --n_eval_sparsity_variance_batches 2000 \
    #   --n_eval_reconstruction_batches 200 \
    #   --output_folder "eval_results/core" \
    #   --exclude_special_tokens_from_reconstruction --verbose
    #
    #   python evals/core/main.py "sae_bench_gemma-2-2b_topk_width-2pow14_date-1109" "blocks.19.hook_resid_post__trainer_2" \
    #   --batch_size_prompts 16 \
    #   --n_eval_sparsity_variance_batches 2000 \
    #   --n_eval_reconstruction_batches 200 \
    #   --output_folder "eval_results/core" \
    #   --exclude_special_tokens_from_reconstruction --verbose --llm_dtype bfloat16
    #
    # Previous (disabled) flow:
    #   args = arg_parser().parse_args()
    #   eval_results = run_evaluations(args)
    raise NotImplementedError("Not implemented, use run_all_evals.py instead")


# Example (kept commented out): adapt this snippet to evaluate custom SAE objects.
# if __name__ == "__main__":
#     import sae_bench.custom_saes.identity_sae as identity_sae
#     import sae_bench.custom_saes.jumprelu_sae as jumprelu_sae

#     start_time = time.time()

#     random_seed = 42
#     output_folder = "eval_results/core"

#     batch_size_prompts = 16
#     n_eval_reconstruction_batches = 20
#     n_eval_sparsity_variance_batches = 20
#     context_size = 128
#     dataset_name = "Skylion007/openwebtext"
#     exclude_special_tokens_from_reconstruction = True

#     model_name = "gemma-2-2b"
#     hook_layer = 20
#     llm_dtype = torch.bfloat16

#     repo_id = "google/gemma-scope-2b-pt-res"
#     filename = f"layer_{hook_layer}/width_16k/average_l0_71/params.npz"
#     sae = jumprelu_sae.load_jumprelu_sae(repo_id, filename, hook_layer)
#     selected_saes = [(f"{repo_id}_{filename}_gemmascope_sae", sae)]

#     # it's recommended to specify the dtype of the SAE
#     for sae_name, sae in selected_saes:
#         sae.cfg.dtype = "bfloat16"

#     multiple_evals(
#         filtered_saes=selected_saes,
#         n_eval_reconstruction_batches=n_eval_reconstruction_batches,
#         n_eval_sparsity_variance_batches=n_eval_sparsity_variance_batches,
#         eval_batch_size_prompts=batch_size_prompts,
#         exclude_special_tokens_from_reconstruction=exclude_special_tokens_from_reconstruction,
#         dataset=dataset_name,
#         context_size=context_size,
#         output_folder=output_folder,
#         verbose=True,
#         dtype=llm_dtype,
#     )
