import h5py


def _filter_out_nosiy_activation(imagenet_mean_acts, features, noisy_threshold=0.1):
    """Zero out latents whose mean activation exceeds ``noisy_threshold``.

    A latent ``i`` is treated as noisy when ``imagenet_mean_acts[i]`` is
    strictly greater than ``noisy_threshold``; every position of ``features``
    along its LAST axis that corresponds to a noisy latent is set to 0.

    Args:
        imagenet_mean_acts: Per-latent mean activations. Indexed via
            ``.nonzero()[0]``, so a 1-D NumPy array is assumed (TODO confirm
            for other array types).
        features: Array whose last axis is the latent axis, of any rank >= 1.
            It is modified in place.
        noisy_threshold: Mean activation above which a latent is considered
            noisy.

    Returns:
        ``features`` itself (the same object, modified in place).
    """
    noisy_features_indices = (imagenet_mean_acts > noisy_threshold).nonzero()[0].tolist()
    # Nothing to zero, or a 0-d input with no latent axis: return unchanged.
    if not noisy_features_indices or len(features.shape) == 0:
        return features
    # The Ellipsis selects the last axis for any rank; the previous rank-1/2/3
    # branches silently skipped filtering for higher-rank inputs.
    features[..., noisy_features_indices] = 0
    return features


def load_activation_data(root, dataset_name, split, latent_indices=None, noisy_threshold=0.1):
    """Load SAE latent activations for one dataset split from HDF5.

    Reads the ``activations`` dataset from
    ``{root}/data/{dataset_name}_analysis/{split}_sae_latents.h5``.

    Args:
        root: Project root directory containing the ``data`` folder.
        dataset_name: Dataset prefix of the ``*_analysis`` directory.
        split: Split prefix of the ``*_sae_latents.h5`` file.
        latent_indices: Optional column selection passed directly to h5py as
            ``hf["activations"][:, latent_indices]``. When given, only those
            columns are read and NO noise filtering is applied. NOTE: h5py
            list-based fancy indexing requires indices in increasing order.
        noisy_threshold: When ``latent_indices`` is None, latents whose
            ImageNet mean activation exceeds this value are zeroed.

    Returns:
        The activations array read from the file (noise-filtered when
        ``latent_indices`` is None).
    """
    data_path = f"{root}/data/{dataset_name}_analysis/{split}_sae_latents.h5"

    if latent_indices is not None:
        # The ImageNet stats are only needed for filtering, so they are not
        # read on this path.
        with h5py.File(data_path, "r") as hf:
            return hf["activations"][:, latent_indices]

    # Explicit "r": before h5py 3.0 the default mode could create/modify files.
    with h5py.File(f"{root}/data/imagenet_analysis/train_sae_stats.h5", "r") as hf:
        mean_acts = hf["sae_mean_acts"][:]
        # Values above 1 are treated as not yet normalised and are divided by
        # the stored per-latent sparsity (presumably firing frequency — TODO
        # confirm against the stats writer).
        if mean_acts.max() > 1:
            sparsity = hf["sae_sparsity"][:]
            mean_acts /= sparsity

    with h5py.File(data_path, "r") as hf:
        activations = hf["activations"][:]
    return _filter_out_nosiy_activation(mean_acts, activations, noisy_threshold=noisy_threshold)


def filter_out_nosiy_activation(root, features, noisy_threshold=0.1):
    """Zero out noisy latents in ``features`` using ImageNet training stats.

    Reads per-latent mean activations from
    ``{root}/data/imagenet_analysis/train_sae_stats.h5`` and zeroes every
    latent (last axis of ``features``) whose mean activation exceeds
    ``noisy_threshold``.

    Args:
        root: Project root directory containing the ``data`` folder.
        features: Array whose last axis is the latent axis; modified in place.
        noisy_threshold: Mean activation above which a latent is considered
            noisy.

    Returns:
        ``features`` itself, with noisy latents set to 0.
    """
    # Explicit "r": before h5py 3.0 the default mode could create/modify files.
    with h5py.File(f"{root}/data/imagenet_analysis/train_sae_stats.h5", "r") as hf:
        mean_acts = hf["sae_mean_acts"][:]
        # Values above 1 are treated as not yet normalised and are divided by
        # the stored per-latent sparsity (presumably firing frequency — TODO
        # confirm against the stats writer).
        if mean_acts.max() > 1:
            sparsity = hf["sae_sparsity"][:]
            mean_acts /= sparsity
    # Delegate to the shared helper so the latent axis is handled the same way
    # as in load_activation_data (the inline `features[:, idx]` failed on 1-D
    # input and hit axis 1 rather than the latent axis on 3-D input).
    return _filter_out_nosiy_activation(mean_acts, features, noisy_threshold=noisy_threshold)
