import torch
from transformer_lens import HookedTransformer

from .base_sae import BaseSAE


class FFKV(BaseSAE):
    """
    A custom SAE wrapper that mirrors a specific FF block instance from a model.
    It treats the intermediate FF neuron activations (after the activation function/gating)
    as SAE features, obtained via hooks.

    This implementation uses TransformerLens hook points exclusively.

    - Input: blocks.{layer}.ln2.hook_normalized - Shape [batch, seq_len, d_model]
      The normalized input to the FF after layer normalization

    - Features: blocks.{layer}.mlp.hook_post - Shape [batch, seq_len, d_mlp]
      The intermediate activations after gating and activation functions are applied

    - Output: blocks.{layer}.hook_mlp_out - Shape [batch, seq_len, d_model]
      The final output of the FF that gets added to the residual stream

    encode / decode / forward all work the same way. Each one runs the wrapped
    model on a one-token dummy prompt up to and including block ``hook_layer``.
    It overwrites the activation at one hook point with the caller's tensor and
    returns the activation recorded at another hook point.

    Inputs that are not 3-D are flattened to ``(1, N, dim)`` for the run and
    reshaped back afterwards. This relies on the hooked span being
    position-wise, which holds for the default FF hook points.
    """

    def __init__(
        self,
        model: HookedTransformer,
        hook_layer: int,
        device: torch.device,
        dtype: torch.dtype,
        hook_points: dict | None = None,
    ):
        """
        Wrap the FF block at ``hook_layer`` of ``model``.

        Args:
            model: The TransformerLens model whose FF block is mirrored.
            hook_layer: Index of the block to mirror.
            device: Device for this SAE's own tensors and returned outputs.
            dtype: Dtype for this SAE's own tensors and returned outputs.
            hook_points: Optional mapping with any of the keys "input",
                "features", "output". Missing keys fall back to the default
                FF hook points for ``hook_layer``.
        """
        # Start from the default FF hook points and let caller-provided entries
        # override them, so a partial mapping is accepted too.
        default_hook_points = {
            "input": f"blocks.{hook_layer}.ln2.hook_normalized",
            "features": f"blocks.{hook_layer}.mlp.hook_post",
            "output": f"blocks.{hook_layer}.hook_mlp_out",
        }
        hook_points = {**default_hook_points, **(hook_points or {})}

        self.input_hook = hook_points["input"]
        self.features_hook = hook_points["features"]
        self.output_hook = hook_points["output"]

        # Get dimensions from model configuration
        d_in = model.cfg.d_model
        d_sae = model.cfg.d_mlp

        # Initialize BaseSAE
        super().__init__(
            d_in=d_in,
            d_sae=d_sae,
            model_name=model.cfg.model_name,
            hook_layer=hook_layer,
            device=device,
            dtype=dtype,
            hook_name=self.input_hook,  # Input hook
        )

        # Cast/move this SAE's own tensors BEFORE registering the model as a
        # submodule. Otherwise `self.to` would also convert the caller's model,
        # leaving its weights out of sync with `model.cfg.device` /
        # `model.cfg.dtype`, which the injection hooks below rely on.
        self.to(device=device, dtype=dtype)

        # Store reference to model for hook-based operations
        self.model = model
        # Expose the FF weights under the SAE parameter names. These are the
        # model's own parameter objects (shared, not copied).
        self.W_enc = model.blocks[hook_layer].mlp.W_in  # type: ignore
        self.W_dec = model.blocks[hook_layer].mlp.W_out  # type: ignore
        self.b_dec = model.blocks[hook_layer].mlp.b_out  # type: ignore

    def _run_with_injection(
        self,
        inject_hook: str,
        value: torch.Tensor,
        capture_hook: str,
        captured_name: str,
    ) -> torch.Tensor:
        """
        Run the model with ``value`` substituted at ``inject_hook`` and return
        the activation recorded at ``capture_hook``.

        Args:
            inject_hook: Hook point whose activation is replaced by ``value``.
            value: Tensor of shape (..., dim) to inject.
            capture_hook: Hook point whose activation is returned.
            captured_name: Noun used in the error message ("features"/"output").

        Returns:
            The captured activation, with the same leading shape as ``value``,
            on ``self.device`` and in ``self.dtype``.

        Raises:
            RuntimeError: If ``capture_hook`` never fired during the run.
        """
        leading_shape = value.shape[:-1]
        # Lay non-3-D inputs out as a single sequence of N positions so the
        # injected tensor has a (batch, pos, dim) layout. 3-D inputs are
        # injected unchanged.
        flatten = value.dim() != 3
        if flatten:
            value = value.reshape(1, -1, value.shape[-1])

        # Convert once, directly to the model's device and native dtype, so the
        # injected tensor matches the model weights. This avoids a lossy detour
        # through self.dtype.
        model_cfg = self.model.cfg
        injected = value.to(device=model_cfg.device, dtype=model_cfg.dtype)
        captured: list[torch.Tensor] = []

        def inject_fn(_act, hook):
            return injected

        def capture_fn(act, hook):
            captured.append(act.clone())
            return act

        # The one-token dummy prompt only drives the forward pass; the real data
        # enters through inject_fn. stop_at_layer=hook_layer + 1 ends the run
        # after block hook_layer.
        self.model.run_with_hooks(
            torch.zeros((1, 1), device=model_cfg.device, dtype=torch.long),
            fwd_hooks=[(inject_hook, inject_fn), (capture_hook, capture_fn)],
            stop_at_layer=self.cfg.hook_layer + 1,
        )

        if not captured:
            raise RuntimeError(f"Hook {capture_hook} did not capture {captured_name}")

        result = captured[0]
        if flatten:
            result = result.reshape(*leading_shape, result.shape[-1])
        return result.to(device=self.device, dtype=self.dtype)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the intermediate FF activations for ``x``.

        ``x`` is injected at the input hook, and the activation at the features
        hook is returned.

        Args:
            x: Input tensor corresponding to ln2.hook_normalized, shape (..., d_in)

        Returns:
            Intermediate activations, shape (..., d_sae)

        Raises:
            RuntimeError: If the features hook did not fire.
        """
        return self._run_with_injection(
            self.input_hook, x, self.features_hook, "features"
        )

    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
        """
        'Decode' intermediate FF activations into the FF output.

        ``feature_acts`` is injected at the features hook, and the activation at
        the output hook is returned.

        Args:
            feature_acts: Intermediate activations, shape (..., d_sae)

        Returns:
            Output tensor, shape (..., d_in)

        Raises:
            RuntimeError: If the output hook did not fire.
        """
        return self._run_with_injection(
            self.features_hook, feature_acts, self.output_hook, "output"
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the full FF block on ``x`` using the model's own hooks.

        ``x`` is injected at the input hook, and the activation at the output
        hook is returned.

        Args:
            x: Input tensor, shape (..., d_in)

        Returns:
            Output tensor, shape (..., d_in)

        Raises:
            RuntimeError: If the output hook did not fire.
        """
        return self._run_with_injection(
            self.input_hook, x, self.output_hook, "output"
        )

    @torch.no_grad()
    def check_decoder_norms(self) -> bool:
        """Decoder norm check is not applicable for this SAE; always returns True."""
        print("⚠️ Decoder norm check is not applicable for FFKV.")
        return True

    @classmethod
    def from_model(
        cls,
        model: HookedTransformer,
        hook_layer: int,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
        hook_points: dict | None = None,
    ):
        """
        Create an FFKV for ``model``'s block ``hook_layer``.

        ``device`` and ``dtype`` default to ``model.cfg.device`` and
        ``model.cfg.dtype`` respectively.
        """
        if device is None:
            device = model.cfg.device  # type: ignore
        if dtype is None:
            dtype = model.cfg.dtype

        return cls(
            model=model,
            hook_layer=hook_layer,
            device=device,  # type: ignore
            dtype=dtype,
            hook_points=hook_points,
        )
