import torch
from transformer_lens import HookedTransformer

from .base_sae import BaseSAE


class KSparseFFKVNormalized(BaseSAE):
    """
    A custom SAE wrapper that mirrors a specific FF block instance from a model.
    It treats the intermediate FF neuron activations (after the activation function/gating)
    as SAE features, obtained via hooks, and keeps only the top-k of them (by absolute
    value) at each position.

    This implementation uses TransformerLens hook points exclusively.

    - Input: blocks.{layer}.ln2.hook_normalized - Shape [batch, seq_len, d_model]
      The normalized input to the FF after layer normalization

    - Features: blocks.{layer}.mlp.hook_post - Shape [batch, seq_len, d_mlp]
      The intermediate activations after gating and activation functions are applied

    - Output: blocks.{layer}.hook_mlp_out - Shape [batch, seq_len, d_model]
      The final output of the FF that gets added to the residual stream

    The rows of ``W_dec`` are the model's ``W_out`` rows rescaled to unit norm. The
    original row norms are kept in the non-persistent buffer ``norms`` and re-applied in
    :meth:`decode`, so ``decode`` reproduces the model's down-projection.

    The wrapped model is stored as a plain attribute, not as a registered submodule.
    As a result, ``.to()``, ``.parameters()`` and ``.state_dict()`` on this SAE never
    touch the model's weights.
    """

    def __init__(
        self,
        model: HookedTransformer,
        hook_layer: int,
        device: torch.device,
        dtype: torch.dtype,
        k: int,
        hook_points: dict[str, str] | None = None,
    ):
        """
        Build the wrapper for block ``hook_layer`` of ``model``.

        Args:
            model: The TransformerLens model whose FF block is mirrored. It is used
                as-is for the hook-based passes; this class never moves or casts it.
            hook_layer: Index of the transformer block to mirror.
            device: Device for this SAE's own tensors.
            dtype: Dtype for this SAE's own tensors.
            k: Number of features kept per position by :meth:`encode`.
            hook_points: Optional overrides for the ``"input"``, ``"features"`` and
                ``"output"`` hook names. Missing keys fall back to the defaults.

        Raises:
            ValueError: If ``k`` is negative.
        """
        if k < 0:
            raise ValueError(f"k must be non-negative, got {k}")

        default_hook_points = {
            "input": f"blocks.{hook_layer}.ln2.hook_normalized",
            "features": f"blocks.{hook_layer}.mlp.hook_post",
            "output": f"blocks.{hook_layer}.hook_mlp_out",
        }
        # Merge so a caller may override a single hook name. A complete dict
        # behaves exactly as before.
        hook_points = {**default_hook_points, **(hook_points or {})}
        input_hook = hook_points["input"]

        # Initialize BaseSAE with dimensions taken from the model configuration.
        super().__init__(
            d_in=model.cfg.d_model,
            d_sae=model.cfg.d_mlp,
            model_name=model.cfg.model_name,
            hook_layer=hook_layer,
            device=device,
            dtype=dtype,
            hook_name=input_hook,  # Input hook
        )
        self.to(device=device, dtype=dtype)

        self.input_hook = input_hook
        self.features_hook = hook_points["features"]
        self.output_hook = hook_points["output"]
        self.k = k  # Number of top activations to keep per position

        # Keep a reference to the model WITHOUT registering it as a submodule.
        # A plain `self.model = model` would make `self.to(...)` move or cast the
        # caller's model in place, leaving model.cfg.device/dtype stale. It would
        # also pull every model parameter into this SAE's parameters()/state_dict().
        object.__setattr__(self, "model", model)

        mlp = model.blocks[hook_layer].mlp  # type: ignore
        W_in = mlp.W_in.detach()
        W_out = mlp.W_out.detach()
        b_out = mlp.b_out.detach()

        normalized_W_out, norms = self._normalize_decoder_weights(W_out)
        # A non-persistent buffer follows later .to() calls but stays out of state_dict().
        self.register_buffer("norms", norms, persistent=False)

        # Copy the weights into this SAE's own device/dtype. Assigning .data with
        # the raw model tensors would silently keep the model's device/dtype.
        # W_enc is informational only: encode() runs the model's own MLP via hooks.
        self.W_enc.data = W_in.to(device=device, dtype=dtype).clone()
        self.W_dec.data = normalized_W_out.to(device=device, dtype=dtype).clone()
        self.b_dec.data = b_out.to(device=device, dtype=dtype).clone()

    def _run_and_capture(
        self,
        inject_hook: str,
        value: torch.Tensor,
        capture_hook: str,
        what: str,
    ) -> torch.Tensor:
        """
        Run the model through block ``self.cfg.hook_layer`` with ``value`` injected.

        The activation at ``inject_hook`` is replaced by ``value``, and the activation
        seen at ``capture_hook`` is returned. A dummy (1, 1) token batch drives the
        forward pass; only the injected activation feeds the captured one.

        Args:
            inject_hook: Hook name whose activation is replaced by ``value``.
            value: Tensor to inject. It is converted to the model's device/dtype.
            capture_hook: Hook name whose activation is returned.
            what: Noun used in the error message ("features" / "output").

        Returns:
            The captured activation on this SAE's device/dtype.

        Raises:
            RuntimeError: If ``capture_hook`` never fired.
        """
        # Convert once, straight to the model's native device/dtype, so the
        # injected tensor matches the model weights.
        injected = value.to(device=self.model.cfg.device, dtype=self.model.cfg.dtype)
        captured: list[torch.Tensor] = []

        def inject_fn(_, hook):
            return injected

        def capture_fn(act, hook):
            captured.append(act.clone())
            return act

        dummy_tokens = torch.zeros(
            (1, 1), device=self.model.cfg.device, dtype=torch.long
        )
        self.model.run_with_hooks(
            dummy_tokens,
            fwd_hooks=[(inject_hook, inject_fn), (capture_hook, capture_fn)],
            stop_at_layer=self.cfg.hook_layer + 1,
        )

        if not captured:
            raise RuntimeError(f"Hook {capture_hook} did not capture {what}")

        return captured[0].to(device=self.device, dtype=self.dtype)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the k-sparse intermediate FF activations for ``x``.

        ``x`` is injected at the input hook, the model's own FF block runs, and the
        post-activation features are captured at the features hook. At each
        position, only the ``k`` entries with the largest absolute value are kept
        (with their original sign); all other entries are zeroed.

        Args:
            x: Input tensor corresponding to ln2.hook_normalized, shape (..., d_in)

        Returns:
            Sparse intermediate activations on this SAE's device/dtype, shape (..., d_sae)
        """
        features = self._run_and_capture(
            self.input_hook, x, self.features_hook, "features"
        )

        # Merge all leading dimensions so top-k runs per position.
        flat = features.reshape(-1, features.shape[-1])
        k = min(self.k, flat.shape[-1])
        # Rank by magnitude, but write back the signed values at the chosen positions.
        _, top_idx = flat.abs().topk(k, dim=-1)
        sparse = torch.zeros_like(flat).scatter(-1, top_idx, flat.gather(-1, top_idx))
        return sparse.reshape(features.shape)

    def hook_based_decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
        """
        'Decodes' the intermediate FF activations to produce the FF output.

        ``feature_acts`` is injected at the features hook, and the model's output hook
        is captured. This uses the model's own weights rather than this SAE's copies.

        Args:
            feature_acts: Intermediate activations, shape (..., d_sae)

        Returns:
            Output tensor on this SAE's device/dtype, shape (..., d_in)
        """
        return self._run_and_capture(
            self.features_hook, feature_acts, self.output_hook, "output"
        )

    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
        """
        Reconstruct the FF output from feature activations using this SAE's weights.

        Computes ``(feature_acts * norms) @ W_dec + b_dec``, which equals the model's
        ``acts @ W_out + b_out`` up to rounding. It then applies the block's
        ``ln2_post`` when that module exists (post-FF normalization, e.g. Gemma-2).

        Args:
            feature_acts: Intermediate activations, shape (..., d_sae). Any number
                of leading dimensions is supported.

        Returns:
            Output tensor on this SAE's device/dtype, shape (..., d_in)
        """
        acts = feature_acts.to(device=self.W_dec.device, dtype=self.W_dec.dtype)
        # Re-apply the per-feature norms that were divided out of W_dec.
        # Broadcasting over the last dimension works for any leading shape.
        scaled = acts * self.norms.to(device=acts.device, dtype=acts.dtype)
        output = scaled @ self.W_dec + self.b_dec

        block = self.model.blocks[self.cfg.hook_layer]
        if hasattr(block, "ln2_post"):
            # ln2_post is a model module, so feed it model-native tensors.
            output = block.ln2_post(  # type: ignore
                output.to(device=self.model.cfg.device, dtype=self.model.cfg.dtype)
            )

        return output.to(device=self.device, dtype=self.dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs the full FF block pass by using the model's hooks directly.

        This is the dense pass: ``x`` is injected at the input hook and the output hook
        is captured, with no top-k sparsification applied.

        Args:
            x: Input tensor, shape (..., d_in)

        Returns:
            Output tensor on this SAE's device/dtype, shape (..., d_in)
        """
        return self._run_and_capture(self.input_hook, x, self.output_hook, "output")

    @torch.no_grad()
    def check_decoder_norms(self) -> bool:
        """Decoder norm check is not applicable for this SAE.

        Always returns True after printing a notice.
        """
        # NOTE(review): W_dec rows are unit-normalized in __init__, so the base
        # check might pass as well -- confirm whether skipping it is still needed.
        print("⚠️Decoder norm check is not applicable for KSparseFFKVNormalized.")
        return True

    @torch.no_grad()
    def _normalize_decoder_weights(self, W_dec: torch.Tensor):
        """
        Normalize each row (reducing over dim 1) of ``W_dec`` to unit L2 norm.

        The computation runs in at least float32, so bf16/fp16 rounding does not
        degrade the stored norms or the unit rows.

        Args:
            W_dec: Decoder weights (here the model's W_out, expected (d_sae, d_in)).

        Returns:
            Tuple ``(normalized_W_dec, original_norms)``:
            - ``normalized_W_dec``: the row-normalized weights, in the (>= float32)
              compute dtype and on ``W_dec``'s device.
            - ``original_norms``: the per-row norms before normalization, on this
              SAE's device/dtype.

        Raises:
            RuntimeError: If a nonzero row does not end up with unit norm.
        """
        eps = 1e-12
        compute_dtype = torch.promote_types(W_dec.dtype, torch.float32)
        weights = W_dec.to(dtype=compute_dtype)

        # Store the original norms BEFORE normalization.
        original_norms = weights.norm(dim=1)
        normalized_W_dec = torch.nn.functional.normalize(weights, dim=1, eps=eps)

        # All-zero (dead) rows stay zero after normalize(), so only check live rows.
        # Raise instead of assert so the check survives `python -O`.
        live = original_norms > eps
        row_norms = normalized_W_dec[live].norm(dim=1)
        if not torch.allclose(row_norms, torch.ones_like(row_norms), atol=1e-5):
            raise RuntimeError("Decoder weight normalization produced non-unit rows")

        return normalized_W_dec, original_norms.to(dtype=self.dtype, device=self.device)

    @torch.no_grad()
    def test_sae(self, model_name: str | None = None):
        """Overridden as a no-op for this wrapper."""
        pass

    @classmethod
    def from_model(
        cls,
        model: HookedTransformer,
        hook_layer: int,
        k: int,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
        hook_points: dict | None = None,
    ):
        """Factory method to create a KSparseFFKVNormalized from a model.

        ``device`` and ``dtype`` default to ``model.cfg.device`` and ``model.cfg.dtype``.
        """
        if device is None:
            device = model.cfg.device  # type: ignore
        if dtype is None:
            dtype = model.cfg.dtype

        return cls(
            model=model,
            hook_layer=hook_layer,
            device=device,  # type: ignore
            dtype=dtype,
            hook_points=hook_points,
            k=k,
        )
