import os
from typing import Any, Dict, Optional

import requests


class NeuronpediaResponse:
    """
    Read-only accessor for the JSON payload returned by Neuronpedia's feature endpoint.

    Each ``get_*`` method looks up a single key in the payload (or in one of its
    nested objects) and returns ``None`` when that key is absent. Nested objects
    ("model", "creator", "source", "sourceSet", "saelensConfig") fall back to ``{}``
    and list fields ("comments", "activations", "explanations") fall back to ``[]``.
    These fallbacks also apply when the API sends an explicit ``null``, so chained
    lookups such as ``get_model_id()`` never fail on ``None``.
    """

    def __init__(self, data: Dict[str, Any]):
        """
        Initialize the parser with the JSON response data.
        :param data: Dictionary returned from the API.
        """
        self.data = data

    # -------------------
    # Internal helpers
    # -------------------
    def _get_dict(self, key: str) -> Dict[str, Any]:
        """Return the nested object stored under ``key``; ``{}`` if it is missing or null."""
        return self.data.get(key) or {}

    def _get_list(self, key: str) -> list:
        """Return the list stored under ``key``; ``[]`` if it is missing or null."""
        return self.data.get(key) or []

    @staticmethod
    def _detokenize(tokens) -> str:
        """Join tokens into text, treating the "▁" prefix as a word boundary."""
        return "".join(tokens).replace("▁", " ").strip()

    # -------------------
    # Top-level keys
    # -------------------
    def get_modelId(self):
        return self.data.get("modelId")

    def get_layer(self):
        return self.data.get("layer")

    def get_index(self):
        return self.data.get("index")

    def get_sourceSetName(self):
        return self.data.get("sourceSetName")

    def get_creatorId(self):
        return self.data.get("creatorId")

    def get_createdAt(self):
        return self.data.get("createdAt")

    def get_maxActApprox(self):
        return self.data.get("maxActApprox")

    def get_hasVector(self):
        return self.data.get("hasVector")

    def get_vector(self):
        return self.data.get("vector")

    def get_vectorLabel(self):
        return self.data.get("vectorLabel")

    def get_vectorDefaultSteerStrength(self):
        return self.data.get("vectorDefaultSteerStrength")

    def get_hookName(self):
        return self.data.get("hookName")

    def get_topkCosSimIndices(self):
        return self.data.get("topkCosSimIndices")

    def get_topkCosSimValues(self):
        return self.data.get("topkCosSimValues")

    def get_neuron_alignment_indices(self):
        return self.data.get("neuron_alignment_indices")

    def get_neuron_alignment_values(self):
        return self.data.get("neuron_alignment_values")

    def get_neuron_alignment_l1(self):
        return self.data.get("neuron_alignment_l1")

    def get_correlated_neurons_indices(self):
        return self.data.get("correlated_neurons_indices")

    def get_correlated_neurons_pearson(self):
        return self.data.get("correlated_neurons_pearson")

    def get_correlated_neurons_l1(self):
        return self.data.get("correlated_neurons_l1")

    def get_correlated_features_indices(self):
        return self.data.get("correlated_features_indices")

    def get_correlated_features_pearson(self):
        return self.data.get("correlated_features_pearson")

    def get_correlated_features_l1(self):
        return self.data.get("correlated_features_l1")

    def get_neg_str(self):
        return self.data.get("neg_str")

    def get_neg_values(self):
        return self.data.get("neg_values")

    def get_pos_str(self):
        return self.data.get("pos_str")

    def get_pos_values(self):
        return self.data.get("pos_values")

    def get_frac_nonzero(self):
        return self.data.get("frac_nonzero")

    def get_freq_hist_data_bar_heights(self):
        return self.data.get("freq_hist_data_bar_heights")

    def get_freq_hist_data_bar_values(self):
        return self.data.get("freq_hist_data_bar_values")

    def get_logits_hist_data_bar_heights(self):
        return self.data.get("logits_hist_data_bar_heights")

    def get_logits_hist_data_bar_values(self):
        return self.data.get("logits_hist_data_bar_values")

    def get_decoder_weights_dist(self):
        return self.data.get("decoder_weights_dist")

    def get_umap_cluster(self):
        return self.data.get("umap_cluster")

    def get_umap_log_feature_sparsity(self):
        return self.data.get("umap_log_feature_sparsity")

    def get_umap_x(self):
        return self.data.get("umap_x")

    def get_umap_y(self):
        return self.data.get("umap_y")

    # -------------------
    # Model information (nested under "model")
    # -------------------
    def get_model(self):
        return self._get_dict("model")

    def get_model_id(self):
        return self.get_model().get("id")

    def get_model_displayNameShort(self):
        return self.get_model().get("displayNameShort")

    def get_model_displayName(self):
        return self.get_model().get("displayName")

    def get_model_creatorId(self):
        return self.get_model().get("creatorId")

    def get_model_tlensId(self):
        return self.get_model().get("tlensId")

    def get_model_dimension(self):
        return self.get_model().get("dimension")

    def get_model_thinking(self):
        return self.get_model().get("thinking")

    def get_model_visibility(self):
        return self.get_model().get("visibility")

    def get_model_defaultSourceSetName(self):
        return self.get_model().get("defaultSourceSetName")

    def get_model_defaultSourceId(self):
        return self.get_model().get("defaultSourceId")

    def get_model_inferenceEnabled(self):
        return self.get_model().get("inferenceEnabled")

    def get_model_instruct(self):
        return self.get_model().get("instruct")

    def get_model_layers(self):
        return self.get_model().get("layers")

    def get_model_neuronsPerLayer(self):
        return self.get_model().get("neuronsPerLayer")

    def get_model_createdAt(self):
        return self.get_model().get("createdAt")

    def get_model_owner(self):
        return self.get_model().get("owner")

    def get_model_updatedAt(self):
        return self.get_model().get("updatedAt")

    def get_model_website(self):
        return self.get_model().get("website")

    # -------------------
    # Creator information (nested under "creator")
    # -------------------
    def get_creator(self):
        return self._get_dict("creator")

    def get_creator_name(self):
        return self.get_creator().get("name")

    # -------------------
    # Source information (nested under "source")
    # -------------------
    def get_source(self):
        return self._get_dict("source")

    def get_source_id(self):
        return self.get_source().get("id")

    def get_source_modelId(self):
        return self.get_source().get("modelId")

    def get_source_hasDashboards(self):
        return self.get_source().get("hasDashboards")

    def get_source_inferenceEnabled(self):
        return self.get_source().get("inferenceEnabled")

    def get_source_saelensConfig(self):
        # `or {}` also covers an explicit null from the API.
        return self.get_source().get("saelensConfig") or {}

    def get_source_saelensRelease(self):
        return self.get_source().get("saelensRelease")

    def get_source_saelensSaeId(self):
        return self.get_source().get("saelensSaeId")

    def get_source_hfRepoId(self):
        return self.get_source().get("hfRepoId")

    def get_source_hfFolderId(self):
        return self.get_source().get("hfFolderId")

    def get_source_visibility(self):
        return self.get_source().get("visibility")

    def get_source_defaultOfModelId(self):
        return self.get_source().get("defaultOfModelId")

    def get_source_setName(self):
        return self.get_source().get("setName")

    def get_source_creatorId(self):
        return self.get_source().get("creatorId")

    def get_source_hasUmap(self):
        return self.get_source().get("hasUmap")

    def get_source_hasUmapLogSparsity(self):
        return self.get_source().get("hasUmapLogSparsity")

    def get_source_hasUmapClusters(self):
        return self.get_source().get("hasUmapClusters")

    def get_source_num_prompts(self):
        return self.get_source().get("num_prompts")

    def get_source_num_tokens_in_prompt(self):
        return self.get_source().get("num_tokens_in_prompt")

    def get_source_dataset(self):
        return self.get_source().get("dataset")

    def get_source_notes(self):
        return self.get_source().get("notes")

    def get_source_cosSimMatchModelId(self):
        return self.get_source().get("cosSimMatchModelId")

    def get_source_cosSimMatchSourceId(self):
        return self.get_source().get("cosSimMatchSourceId")

    def get_source_createdAt(self):
        return self.get_source().get("createdAt")

    # -------------------
    # SourceSet information (nested under "sourceSet")
    # -------------------
    def get_sourceSet(self):
        return self._get_dict("sourceSet")

    def get_sourceSet_modelId(self):
        return self.get_sourceSet().get("modelId")

    def get_sourceSet_name(self):
        return self.get_sourceSet().get("name")

    def get_sourceSet_hasDashboards(self):
        return self.get_sourceSet().get("hasDashboards")

    def get_sourceSet_allowInferenceSearch(self):
        return self.get_sourceSet().get("allowInferenceSearch")

    def get_sourceSet_visibility(self):
        return self.get_sourceSet().get("visibility")

    def get_sourceSet_description(self):
        return self.get_sourceSet().get("description")

    def get_sourceSet_type(self):
        return self.get_sourceSet().get("type")

    def get_sourceSet_creatorName(self):
        return self.get_sourceSet().get("creatorName")

    def get_sourceSet_urls(self):
        return self.get_sourceSet().get("urls")

    def get_sourceSet_creatorEmail(self):
        return self.get_sourceSet().get("creatorEmail")

    def get_sourceSet_creatorId(self):
        return self.get_sourceSet().get("creatorId")

    def get_sourceSet_releaseName(self):
        return self.get_sourceSet().get("releaseName")

    def get_sourceSet_defaultOfModelId(self):
        return self.get_sourceSet().get("defaultOfModelId")

    def get_sourceSet_defaultRange(self):
        return self.get_sourceSet().get("defaultRange")

    def get_sourceSet_defaultShowBreaks(self):
        return self.get_sourceSet().get("defaultShowBreaks")

    def get_sourceSet_showDfa(self):
        return self.get_sourceSet().get("showDfa")

    def get_sourceSet_showCorrelated(self):
        return self.get_sourceSet().get("showCorrelated")

    def get_sourceSet_showHeadAttribution(self):
        return self.get_sourceSet().get("showHeadAttribution")

    def get_sourceSet_showUmap(self):
        return self.get_sourceSet().get("showUmap")

    def get_sourceSet_createdAt(self):
        return self.get_sourceSet().get("createdAt")

    # -------------------
    # Comments (list)
    # -------------------
    def get_comments(self):
        return self._get_list("comments")

    # -------------------
    # Activations (list)
    # -------------------
    def get_activations(self):
        return self._get_list("activations")

    def get_activation_by_index(self, idx: int):
        """Return the activation at ``idx``; negative indices are rejected.

        :raises IndexError: if ``idx`` is outside ``[0, len(activations))``.
        """
        activations = self.get_activations()
        if idx < 0 or idx >= len(activations):
            raise IndexError("Activation index out of range.")
        return activations[idx]

    def get_activation_tokens(self, idx: int):
        activation = self.get_activation_by_index(idx)
        return activation.get("tokens", [])

    def get_activation_maxValue(self, idx: int):
        activation = self.get_activation_by_index(idx)
        return activation.get("maxValue")

    def get_activation_values(self, idx: int):
        activation = self.get_activation_by_index(idx)
        return activation.get("values", [])

    # -------------------
    # Explanations (list)
    # -------------------
    def get_explanations(self):
        return self._get_list("explanations")

    def get_explanation_by_index(self, idx: int):
        """Return the explanation at ``idx``; negative indices are rejected.

        :raises IndexError: if ``idx`` is outside ``[0, len(explanations))``.
        """
        explanations = self.get_explanations()
        if idx < 0 or idx >= len(explanations):
            raise IndexError("Explanation index out of range.")
        return explanations[idx]

    def get_all_explanation_descriptions(self):
        """Returns a list of all explanation descriptions."""
        return [ex.get("description", "") for ex in self.get_explanations()]

    def get_all_explanation_model_names(self):
        """Returns a list of all explanation model names."""
        return [ex.get("explanationModelName", "") for ex in self.get_explanations()]

    def get_explanation_by_model_name(self, model_name: str):
        """
        Return the description of the first explanation produced by ``model_name``.

        Explanations without an "explanationModelName" are treated as having
        the name ``""``.

        :raises ValueError: if no explanation carries that model name.
        """
        for ex in self.get_explanations():
            if ex.get("explanationModelName", "") == model_name:
                return ex.get("description", "")
        raise ValueError(f"No explanation found for model name {model_name!r}.")

    def get_contexts_around_top_n_activations(self, n=3, window=5):
        """
        Returns a list of context strings from the top n activations, where each context string
        is built from a window of tokens around the token with the highest activation value
        in that activation.

        Activations are ranked by "maxValue" (a missing or null value ranks as 0).
        Activations whose maxValue duplicates one already used are skipped.
        Tokens are joined and the "▁" word-boundary marker is turned into a space.
        When an activation's tokens and values differ in length, its full token
        sequence is used instead of a window.

        :param n: The number of top activating examples (activations) to return;
            ``n <= 0`` returns an empty list.
        :param window: The number of tokens before and after the max activation token to
            include (negative values are treated as 0).
        :return: List of context strings for the top n activations.
        """
        if n <= 0:
            return []
        window = max(0, window)

        def max_value(act):
            # A null maxValue ranks like 0 instead of breaking the sort with a
            # None-vs-float comparison.
            return act.get("maxValue") or 0

        contexts = []
        seen_max_values = set()
        # sorted() is stable (also with reverse=True), so among activations with
        # equal maxValue the earliest one in the payload is kept.
        for act in sorted(self.get_activations(), key=max_value, reverse=True):
            max_val = max_value(act)
            if max_val in seen_max_values:
                continue
            seen_max_values.add(max_val)
            tokens = act.get("tokens") or []
            values = act.get("values") or []
            if tokens and values and len(tokens) == len(values):
                # values.index(max(...)) picks the first peak, like max(range, key=...).
                peak = values.index(max(values))
                tokens = tokens[max(0, peak - window): peak + window + 1]
            contexts.append(self._detokenize(tokens))
            if len(contexts) == n:
                break
        return contexts


def get_feature_url(neuronpedia_id: str, index: int = 321) -> str:
    """
    Build the browser URL of a feature's Neuronpedia page.

    :param neuronpedia_id: Model/source path, e.g. "gemma-2-9b-it/9-gemmascope-res-131k".
    :param index: Feature index within that source.
    :return: The page URL.
    """
    page_url = f"https://www.neuronpedia.org/{neuronpedia_id}/{index}"
    return page_url


def get_feature_description(
    neuronpedia_id: str,
    index: int = 321,
    api_key: Optional[str] = None,
    timeout: Optional[float] = 30.0,
) -> "NeuronpediaResponse":
    """
    Get the auto-interpretation description for a specific feature from Neuronpedia.

    Args:
        neuronpedia_id (str): The neuronpedia ID (e.g., "gemma-2-9b-it/9-gemmascope-res-131k")
        index (int): The index of the feature
        api_key (Optional[str]): Neuronpedia API key, defaults to the
            NEURONPEDIA_API_KEY environment variable
        timeout (Optional[float]): Seconds to wait for the HTTP request before
            giving up; ``None`` waits indefinitely.

    Returns:
        NeuronpediaResponse: Wrapper around the decoded JSON feature data
        (including explanations).

    Raises:
        ValueError: If no API key is given and NEURONPEDIA_API_KEY is unset or empty.
        RuntimeError: If the API answers with a non-200 status code.
        requests.exceptions.RequestException: On connection errors or timeouts.
    """
    # Get API key from environment if not provided
    if api_key is None:
        api_key = os.environ.get("NEURONPEDIA_API_KEY")
        if not api_key:
            raise ValueError(
                "Please provide an API key or set NEURONPEDIA_API_KEY environment variable"
            )

    url = f"https://www.neuronpedia.org/api/feature/{neuronpedia_id}/{index}"
    headers = {"Authorization": f"Bearer {api_key}"}

    # Without a timeout, requests can block forever on an unresponsive server.
    response = requests.get(url, headers=headers, timeout=timeout)

    if response.status_code != 200:
        # RuntimeError subclasses Exception, so existing `except Exception` callers still work.
        raise RuntimeError(
            f"Error fetching feature: {response.status_code} - {response.text}"
        )
    return NeuronpediaResponse(response.json())


if __name__ == "__main__":
    # Demo: fetch one feature and print the text around its strongest activations.
    # get_feature_description already returns a NeuronpediaResponse; wrapping it a
    # second time would make `.data` a NeuronpediaResponse and break every getter.
    feature = get_feature_description(
        neuronpedia_id="gemma-2-9b-it/9-gemmascope-res-131k", index=321
    )
    print(feature.get_contexts_around_top_n_activations())
