import os
from typing import Any, Dict, Optional, Tuple

import requests
import torch
from sae_lens import SAE
from transformer_lens import HookedTransformer


def tokenize_text_with_placeholder(
    tokenizer,
    template,
    variable_context,
    verbose=False,
    skip_special_tokens=True,
    placeholder="{0}",
):
    """
    Tokenize a template piecewise, substituting text for its placeholder.

    The template is cut at the first occurrence of ``placeholder``; the text
    before it, ``variable_context``, and the text after it are each encoded
    separately and the resulting id lists are concatenated in that order.

    Args:
        tokenizer: Object exposing ``encode(text, add_special_tokens=...)``
            and ``decode(ids)``.
        template: Text containing ``placeholder``.
        variable_context: Text that takes the placeholder's position.
        verbose: When true, print the decoded full id sequence.
        skip_special_tokens: When false, the leading segment is encoded with
            ``add_special_tokens=True``. The other two segments are always
            encoded without special tokens.
        placeholder: Marker to look for in ``template``.

    Returns:
        A ``(token_ids, fixed_positions)`` pair: a ``torch`` tensor built from
        the concatenated ids, and a list of booleans of the same length that
        is ``True`` for ids coming from the template and ``False`` for ids
        coming from ``variable_context``.

    Raises:
        ValueError: If ``placeholder`` does not occur in ``template``.
    """
    split_at = template.find(placeholder)
    if split_at < 0:
        raise ValueError(f"Placeholder '{placeholder}' not found in template")

    prefix_text = template[:split_at]
    suffix_text = template[split_at + len(placeholder) :]

    # Only the leading segment may receive special tokens.
    prefix_ids = tokenizer.encode(
        prefix_text, add_special_tokens=not skip_special_tokens
    )
    inserted_ids = tokenizer.encode(variable_context, add_special_tokens=False)
    suffix_ids = tokenizer.encode(suffix_text, add_special_tokens=False)

    token_ids = prefix_ids + inserted_ids + suffix_ids

    # Build the mask segment by segment: template ids are fixed, inserted ids are not.
    fixed_positions = []
    for segment_ids, is_fixed in (
        (prefix_ids, True),
        (inserted_ids, False),
        (suffix_ids, True),
    ):
        fixed_positions += [is_fixed] * len(segment_ids)

    if verbose:
        print(f"Full text: {tokenizer.decode(token_ids)}")

    return torch.tensor(token_ids), fixed_positions


def load_sae_saelens(
    release: str = "gemma-scope-9b-it-res-canonical",
    sae_id: str = "layer_9/width_131k/canonical",
    device: str = "cuda",
    dtype: str = "bfloat16",
):
    """
    Load a pretrained SAE through ``SAE.from_pretrained`` and cast it.

    Args:
        release: Release name forwarded to ``SAE.from_pretrained``.
        sae_id: SAE identifier forwarded to ``SAE.from_pretrained``.
        device: Device forwarded to ``SAE.from_pretrained``.
        dtype: Either the name of a ``torch`` attribute (e.g. ``"bfloat16"``)
            or a value that is passed to ``.to()`` unchanged.

    Returns:
        The result of calling ``.to(dtype)`` on the loaded SAE.
    """
    # The loader's result is unpacked as a 3-tuple; only the first item is kept.
    sae, _, _ = SAE.from_pretrained(release=release, sae_id=sae_id, device=device)

    # A string is resolved as an attribute of torch; anything else is used as given.
    if isinstance(dtype, str):
        target_dtype = getattr(torch, dtype)
    else:
        target_dtype = dtype

    return sae.to(target_dtype)


def load_model_tlens(
    model_name: str = "google/gemma-2b-it",
    device: str = "cuda",
    dtype: str = "bfloat16",
    load_with_no_processing: bool = False,
):
    """
    Load a ``HookedTransformer`` by name.

    Args:
        model_name: Model name passed to the loader.
        device: Device passed to the loader.
        dtype: Dtype passed to the loader.
        load_with_no_processing: When true, use
            ``HookedTransformer.from_pretrained_no_processing``; otherwise use
            ``HookedTransformer.from_pretrained``.

    Returns:
        The model returned by the selected loader.
    """
    # Both loaders take the same arguments, so pick one and call it once.
    loader = (
        HookedTransformer.from_pretrained_no_processing
        if load_with_no_processing
        else HookedTransformer.from_pretrained
    )
    return loader(model_name, dtype=dtype, device=device)


class NeuronpediaResponse:
    """
    Read-only accessor around a Neuronpedia feature JSON payload.

    The accessors read the ``"activations"`` and ``"explanations"`` lists of
    the wrapped dictionary; a missing key is treated as an empty list.
    """

    def __init__(self, data: dict):
        """
        Initialize the parser with the JSON response data.
        :param data: Dictionary returned from the API.
        """
        self.data = data

    # -------------------
    # Activations (list)
    # -------------------
    def get_activations(self):
        """Return the ``"activations"`` list, or ``[]`` when the key is absent."""
        return self.data.get("activations", [])

    def get_activation_by_index(self, idx: int):
        """
        Return the activation record at ``idx``.

        Negative indices are rejected rather than wrapping around.

        :raises IndexError: If ``idx`` is negative or past the end of the list.
        """
        activations = self.get_activations()
        if idx < 0 or idx >= len(activations):
            raise IndexError("Activation index out of range.")
        return activations[idx]

    def get_activation_tokens(self, idx: int):
        """Return the ``"tokens"`` of activation ``idx`` (``[]`` if absent)."""
        activation = self.get_activation_by_index(idx)
        return activation.get("tokens", [])

    def get_activation_maxValue(self, idx: int):
        """Return the ``"maxValue"`` of activation ``idx`` (``None`` if absent)."""
        activation = self.get_activation_by_index(idx)
        return activation.get("maxValue")

    def get_activation_values(self, idx: int):
        """Return the ``"values"`` of activation ``idx`` (``[]`` if absent)."""
        activation = self.get_activation_by_index(idx)
        return activation.get("values", [])

    # -------------------
    # Explanations (list)
    # -------------------
    def get_explanations(self):
        """Return the ``"explanations"`` list, or ``[]`` when the key is absent."""
        return self.data.get("explanations", [])

    def get_explanation_by_index(self, idx: int):
        """
        Return the explanation record at ``idx``.

        Negative indices are rejected rather than wrapping around.

        :raises IndexError: If ``idx`` is negative or past the end of the list.
        """
        explanations = self.get_explanations()
        if idx < 0 or idx >= len(explanations):
            raise IndexError("Explanation index out of range.")
        return explanations[idx]

    def get_all_explanation_descriptions(self):
        """Returns a list of all explanation descriptions."""
        return [ex.get("description", "") for ex in self.get_explanations()]

    def get_all_explanation_model_names(self):
        """Returns a list of all explanation model names."""
        return [ex.get("explanationModelName", "") for ex in self.get_explanations()]

    def get_explanation_by_model_name(self, model_name: str):
        """
        Returns the description of the first explanation for the given model name.

        :param model_name: Value to match against ``"explanationModelName"``.
        :raises ValueError: If no explanation carries that model name.
        """
        for explanation in self.get_explanations():
            if explanation.get("explanationModelName", "") == model_name:
                return explanation.get("description", "")
        raise ValueError(f"No explanation found for model name {model_name!r}.")

    @staticmethod
    def _max_value(activation):
        """Return the activation's ``"maxValue"``, treating a missing or null value as 0."""
        value = activation.get("maxValue")
        return 0 if value is None else value

    def get_contexts_around_top_n_activations(self, n=3, window=5):
        """
        Returns a list of context strings from the top n activations, where each context string
        is built from a window of tokens around the token with the highest activation value
        in that activation.

        Duplicate activations (by maxValue) are skipped.

        If an activation has no tokens, no values, or tokens and values of different
        lengths, its tokens are simply joined with spaces instead of being windowed.

        :param n: The number of top activating examples (activations) to return.
                  A value of 0 or less yields an empty list.
        :param window: The number of tokens before and after the max activation token to include.
        :return: List of context strings for the top n activations.
        """
        # Without this guard the "len(contexts) == n" stop condition below is
        # never met for n <= 0 and every context would be returned.
        if n <= 0:
            return []

        activations = self.get_activations()
        if not activations:
            return []

        # Sort activations descending by their max activation value. A null
        # maxValue is ranked as 0 so it cannot break the comparison.
        sorted_activations = sorted(activations, key=self._max_value, reverse=True)

        contexts = []
        seen_max_values = set()
        # Iterate through sorted activations and skip ones with duplicate maxValue.
        for act in sorted_activations:
            max_val = self._max_value(act)
            if max_val in seen_max_values:
                continue
            seen_max_values.add(max_val)

            # "or []" also covers keys that are present but null.
            tokens = act.get("tokens") or []
            values = act.get("values") or []
            if not tokens or not values or len(tokens) != len(values):
                context_text = " ".join(tokens)
            else:
                # Index of the first occurrence of the largest value.
                max_index = max(range(len(values)), key=values.__getitem__)
                start_index = max(0, max_index - window)
                end_index = min(len(tokens), max_index + window + 1)
                context_tokens = tokens[start_index:end_index]
                # Reconstruct the text: assuming the "▁" prefix marks a new word.
                context_text = "".join(context_tokens).replace("▁", " ").strip()
            contexts.append(context_text)
            if len(contexts) == n:
                break
        return contexts


def get_neuronpedia_info(
    neuronpedia_id: str,
    index: int = 321,
    api_key: Optional[str] = None,
    timeout: float = 30.0,
) -> "NeuronpediaResponse":
    """
    Fetch the data for a specific feature from Neuronpedia.

    Args:
        neuronpedia_id (str): The neuronpedia ID (e.g., "gemma-2-9b-it/9-gemmascope-res-131k")
        index (int): The index of the feature
        api_key (Optional[str]): Neuronpedia API key, defaults to the
            NEURONPEDIA_API_KEY environment variable
        timeout (float): Seconds to wait for the server before giving up

    Returns:
        NeuronpediaResponse: Wrapper around the decoded JSON body of the response

    Raises:
        ValueError: If no API key is given and NEURONPEDIA_API_KEY is unset or empty
        Exception: If the server answers with a status code other than 200
    """
    # Get API key from environment if not provided
    if api_key is None:
        api_key = os.environ.get("NEURONPEDIA_API_KEY")
        if not api_key:
            raise ValueError(
                "Please provide an API key or set NEURONPEDIA_API_KEY environment variable"
            )

    # Set up the API request
    url = f"https://www.neuronpedia.org/api/feature/{neuronpedia_id}/{index}"
    headers = {"Authorization": f"Bearer {api_key}"}

    # Without an explicit timeout the request could block indefinitely.
    response = requests.get(url, headers=headers, timeout=timeout)

    # Handle response
    if response.status_code != 200:
        raise Exception(
            f"Error fetching feature: {response.status_code} - {response.text}"
        )
    return NeuronpediaResponse(response.json())


def get_saelens_release_and_id(
    neuronpedia_id: str, api_key: Optional[str] = None, timeout: float = 30.0
) -> Tuple[str, str]:
    """
    Look up the SAE Lens release and SAE id recorded for a Neuronpedia source.

    The lookup requests feature 0 of ``neuronpedia_id`` and reads the
    ``"saelensRelease"`` and ``"saelensSaeId"`` entries of the ``"source"``
    object in the JSON response.

    Args:
        neuronpedia_id (str): The neuronpedia ID used to build the request URL
        api_key (Optional[str]): Neuronpedia API key, defaults to the
            NEURONPEDIA_API_KEY environment variable
        timeout (float): Seconds to wait for the server before giving up

    Returns:
        Tuple[str, str]: ``(saelens_release, saelens_sae_id)``

    Raises:
        ValueError: If no API key is given and NEURONPEDIA_API_KEY is unset or empty
        Exception: If the server answers with a status code other than 200
        KeyError: If the response JSON lacks the expected ``"source"`` entries
    """
    # Get API key from environment if not provided
    if api_key is None:
        api_key = os.environ.get("NEURONPEDIA_API_KEY")
        if not api_key:
            raise ValueError(
                "Please provide an API key or set NEURONPEDIA_API_KEY environment variable"
            )

    # Set up the API request
    url = f"https://www.neuronpedia.org/api/feature/{neuronpedia_id}/0"
    headers = {"Authorization": f"Bearer {api_key}"}

    # Without an explicit timeout the request could block indefinitely.
    response = requests.get(url, headers=headers, timeout=timeout)

    # Fail with the server's own message instead of a confusing KeyError or
    # JSON decoding error further down.
    if response.status_code != 200:
        raise Exception(
            f"Error fetching feature: {response.status_code} - {response.text}"
        )

    source = response.json()["source"]
    return source["saelensRelease"], source["saelensSaeId"]
