from typing import Any, Dict, List, Tuple

import torch
import torch.nn.functional as F
from custom_dreamy.i_runner import IRunner
from sae_lens import SAE
from torch import Tensor
from transformer_lens import HookedTransformer
from transformers import PreTrainedTokenizerBase


class TlensTokenRunner(IRunner):
    """Runner whose target is how far one token's logit leads all other tokens.

    The target is ``logit[token_pos] - max(logit[every other token])`` at the
    final sequence position, so it is positive exactly when ``token_pos`` has
    the strictly largest logit.
    """

    def __init__(
        self,
        model: HookedTransformer,
        tokenizer: PreTrainedTokenizerBase,
        token_pos: int,
    ):
        """
        Args:
            model: TransformerLens model to run.
            tokenizer: Tokenizer paired with ``model`` (stored; not used here).
            token_pos: Index into the logits' last (vocabulary) dimension of the
                token whose lead is targeted.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.token_pos = token_pos

    def run_with_embeddings(
        self, input_embeddings: Tensor
    ) -> Tuple[
        Tensor,
        Tensor,
        Dict[str, Any],
    ]:
        """
        Run the model with embedded inputs and return the target value and logits.

        Args:
            input_embeddings: Embedded input tensor

        Returns:
            Tuple of (target, logits, {}) where target is, per batch element,
            the target token's logit minus the largest logit of any other token
            at the final position.
        """
        model_logits = self.model(
            input_embeddings,
            start_at_layer=0,  # needed to skip embedding layer
            return_type="logits",
        )

        # Get logits for the final position
        final_logits = model_logits[:, -1, :]
        target_logit = final_logits[:, self.token_pos]

        # Exclude the target token by *filling* its slot with -inf. Multiplying
        # by a mask that contains -inf (the previous approach) maps a negative
        # target logit to +inf and a zero logit to NaN, which corrupted the max.
        exclude_target = torch.zeros_like(final_logits, dtype=torch.bool)
        exclude_target[:, self.token_pos] = True
        max_other_logit = (
            final_logits.masked_fill(exclude_target, float("-inf"))
            .max(dim=-1)
            .values
        )

        # Calculate logit difference (target - max_other)
        target = target_logit - max_other_logit

        return target, model_logits, {}

    def one_hot_to_embed(self, one_hot: torch.Tensor) -> torch.Tensor:
        """Convert one-hot encoded input to embedded input."""
        embed_matrix = self.model.embed.W_E
        return torch.matmul(one_hot, embed_matrix)

    def int_ids_to_embed(self, int_ids: torch.Tensor) -> torch.Tensor:
        """Convert integer token IDs to embedded input."""
        return self.model.embed(int_ids)


class TlensTokenDiffRunner(IRunner):
    """Runner whose target is the logit gap between two tokens.

    The target is ``logit[token_pos_a] - logit[token_pos_b]`` at the final
    sequence position. The residual stream after block ``3 * n_layers // 4``
    is also captured and returned, detached, as ``"intermediate_residual"``.
    """

    def __init__(
        self,
        model: HookedTransformer,
        tokenizer: PreTrainedTokenizerBase,
        token_pos_a: int,
        token_pos_b: int,
    ):
        """
        Args:
            model: TransformerLens model to run.
            tokenizer: Tokenizer paired with ``model`` (stored; not used here).
            token_pos_a: Logit index counted positively in the target.
            token_pos_b: Logit index subtracted in the target.
        """
        self.model, self.tokenizer = model, tokenizer
        self.token_pos_a, self.token_pos_b = token_pos_a, token_pos_b

        # Block whose hook_resid_post is captured: three quarters of the depth.
        self.intermediate_layer = (3 * self.model.cfg.n_layers) // 4

    def run_with_embeddings(
        self, input_embeddings: Tensor
    ) -> Tuple[
        Tensor,
        Tensor,
        Dict[str, Any],
    ]:
        """
        Run the model on embedded inputs and compute the two-token logit gap.

        Args:
            input_embeddings: Embedded input tensor

        Returns:
            Tuple of (target, logits, extras) where target is
            ``logit[token_pos_a] - logit[token_pos_b]`` at the last position and
            extras holds ``"intermediate_residual"`` (None if the hook never ran).
        """
        captured: Dict[str, Any] = {}

        def _capture_resid(resid, hook):
            # NOTE(review): detach() means anything computed from this tensor
            # downstream carries no gradient back to input_embeddings — confirm
            # that is intended for consumers such as DivergenceRunner.
            captured["intermediate_residual"] = resid.detach()
            return resid

        resid_hook = f"blocks.{self.intermediate_layer}.hook_resid_post"
        with self.model.hooks(fwd_hooks=[(resid_hook, _capture_resid)]):
            model_logits = self.model(
                input_embeddings,
                start_at_layer=0,  # needed to skip embedding layer
                return_type="logits",
            )

        last_position = model_logits[:, -1, :]
        target = last_position[:, self.token_pos_a] - last_position[:, self.token_pos_b]

        extras = {"intermediate_residual": captured.get("intermediate_residual")}
        return target, model_logits, extras

    def one_hot_to_embed(self, one_hot: torch.Tensor) -> torch.Tensor:
        """Project one-hot encoded tokens through the embedding matrix ``W_E``."""
        return one_hot @ self.model.embed.W_E

    def int_ids_to_embed(self, int_ids: torch.Tensor) -> torch.Tensor:
        """Embed integer token ids with the model's embedding layer."""
        return self.model.embed(int_ids)


class SAERunnerHuggingface(IRunner):
    """Runner whose target is an SAE feature's pre-activation on a HF model.

    The residual stream at the SAE's hook point is captured with a PyTorch
    forward hook, encoded with the SAE encoder weights, and the chosen
    feature's pre-activation is averaged over sequence positions.
    """

    def __init__(self, model: Any, sae: SAE, feature: int):
        """
        Initialize an SAERunner for standard SAE configurations.

        Args:
            model: The model to run. Assumed to be a Hugging Face causal LM that
                exposes its decoder layers as ``model.model.layers`` and returns
                an object with ``.logits`` — TODO confirm for other architectures.
            sae: The sparse autoencoder; must hook a ``hook_resid_post`` point
                and use the ``"standard"`` architecture.
            feature: The feature index to target

        Raises:
            AssertionError: If the SAE's hook point or architecture is unsupported.
        """
        self.model = model
        self.sae = sae
        self.feature = feature

        assert "hook_resid_post" in sae.cfg.hook_name, (
            "Assuming residual stream, see SAERunnerTLens which supports other hooks"
        )
        assert sae.cfg.architecture == "standard", "todo: support other architectures"

        # Presumably the L in ``blocks.L.hook_resid_post``.
        self.layer = sae.cfg.hook_layer

    def run_with_embeddings(
        self, input_embeddings: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, dict]:
        """
        Run the model with embedded inputs and return the target feature activation.

        Args:
            input_embeddings: Embedded input tensor, passed as ``inputs_embeds``.

        Returns:
            Tuple of (target, logits, {}) where target is, per batch element,
            the feature's encoder pre-activation averaged over sequence positions.
        """
        out = {}

        def get_target(
            module: torch.nn.Module,
            inputs: tuple[torch.Tensor, ...],
            output: Any,
        ) -> None:
            # Decoder layers usually return a tuple whose first element is the
            # hidden state, but some transformers versions return the tensor
            # itself; indexing [0] on a bare tensor would silently pick the
            # first batch element instead.
            hidden = output[0] if isinstance(output, tuple) else output
            sae_in = self.sae.process_sae_in(hidden)

            # "... d_in, d_in d_sae -> ... d_sae"
            hidden_pre = sae_in @ self.sae.W_enc + self.sae.b_enc
            hidden_pre = hidden_pre.mean(dim=1)  # Average across sequence length

            out["target"] = hidden_pre[:, self.feature]

        # ``blocks.L.hook_resid_post`` is the residual stream *after* block L,
        # i.e. the output of decoder layer L. The previous code hooked layer
        # L + 1 (one block too late, and out of range for the final layer).
        # NOTE(review): verify against a TransformerLens run of the same SAE.
        layer_module = self.model.model.layers[self.layer]
        handle = layer_module.register_forward_hook(get_target)
        try:
            logits = self.model(inputs_embeds=input_embeddings).logits
        finally:
            # Always detach the hook, even if the forward pass raises.
            handle.remove()

        return out["target"], logits, {}

    def one_hot_to_embed(self, one_hot: torch.Tensor) -> torch.Tensor:
        """Convert one-hot encoded input to embedded input."""
        # NOTE(review): this helper is not imported at the top of this file;
        # presumably it is defined elsewhere — verify it is in scope.
        return transfomer_embed_one_hot_input(one_hot, self.model)

    def int_ids_to_embed(self, int_ids: torch.Tensor) -> torch.Tensor:
        """Convert integer token IDs to embedded input."""
        return self.model.get_input_embeddings()(int_ids)


class SAERunnerTLens(IRunner):
    """Runner whose target is one SAE feature's pre-activation on a TransformerLens model.

    The SAE is applied to the activation named by ``sae.cfg.hook_name``; its
    ``hook_sae_acts_pre`` value for the chosen feature is averaged over
    sequence positions.
    """

    def __init__(self, model: HookedTransformer, sae: SAE, feature: int):
        """
        Set up the runner for a TransformerLens model.

        Args:
            model: The TransformerLens model
            sae: The sparse autoencoder
            feature: Index of the SAE feature to target

        Raises:
            AssertionError: If the model uses a post-embedding LayerNorm
                (presumably unsupported because ``start_at_layer=0`` skips it).
        """
        self.model = model
        self.sae = sae
        self.feature = feature
        self.hook_name = sae.cfg.hook_name

        assert not model.cfg.post_embedding_ln, (
            "post embedding ln currently not supported"
        )

    def run_with_embeddings(
        self, input_embeddings: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, dict]:
        """
        Run the model on embedded inputs and score the chosen SAE feature.

        Args:
            input_embeddings: Embedded input tensor

        Returns:
            Tuple of (target, logits, extras) where target is the feature's
            pre-activation averaged over positions and extras holds the
            per-position values under ``"per_token_target"``.
        """
        results = {}

        def _score(resid, hook):
            grabbed = []

            def _grab_pre(acts: torch.Tensor, hook):
                grabbed.append(acts)
                return acts

            # Run the SAE on the captured activation only to record its pre-activations.
            self.sae.run_with_hooks(
                resid,
                fwd_hooks=[("hook_sae_acts_pre", _grab_pre)],
            )
            feature_acts = grabbed[0][:, :, self.feature]
            results["target"] = feature_acts.mean(dim=1)
            results["per_token_target"] = feature_acts
            # Hand the activation back unchanged so the model forward is unaffected.
            return resid

        with self.model.hooks(fwd_hooks=[(self.hook_name, _score)]):
            logits = self.model(
                input_embeddings,
                start_at_layer=0,  # needed to skip embedding layer
                return_type="logits",
            )

        extras = {"per_token_target": results["per_token_target"]}
        return results["target"], logits, extras

    def one_hot_to_embed(self, one_hot: torch.Tensor) -> torch.Tensor:
        """Project one-hot encoded tokens through the embedding matrix ``W_E``."""
        return one_hot @ self.model.embed.W_E

    def int_ids_to_embed(self, int_ids: torch.Tensor) -> torch.Tensor:
        """Embed integer token ids with the model's embedding layer."""
        return self.model.embed(int_ids)


class L1SAERunnerTLens(IRunner):
    """Runner whose target is a fractional-power ("L0.1") score over two SAE features.

    Each selected feature's pre-activation is shifted by +10 and clamped at
    zero. The score is the sum of ``|x| ** 0.1`` over the features, averaged
    over sequence positions.
    """

    def __init__(self, model: HookedTransformer, sae: SAE, features: List[int]):
        """
        Set up the runner for a TransformerLens model.

        Args:
            model: The TransformerLens model
            sae: The sparse autoencoder
            features: Exactly two SAE feature indices to combine

        Raises:
            AssertionError: If ``features`` does not have length 2, or the model
                uses a post-embedding LayerNorm.
        """
        self.model = model
        self.sae = sae
        self.features = features
        assert len(features) == 2, (
            "this is from the blog and only supports two features"
        )
        self.hook_name = sae.cfg.hook_name

        assert not model.cfg.post_embedding_ln, (
            "post embedding ln currently not supported"
        )

    def run_with_embeddings(
        self, input_embeddings: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, dict]:
        """
        Run the model on embedded inputs and compute the L0.1 feature score.

        Args:
            input_embeddings: Embedded input tensor

        Returns:
            Tuple of (target, logits, extras) where extras holds the shifted,
            clamped per-position feature values under ``"per_token_target"``.
        """
        results = {}

        def _score(resid, hook):
            grabbed = []

            def _grab_pre(acts: torch.Tensor, hook):
                grabbed.append(acts)
                return acts

            self.sae.run_with_hooks(
                resid,
                fwd_hooks=[("hook_sae_acts_pre", _grab_pre)],
            )
            chosen = grabbed[0][:, :, self.features]

            # Shift by +10 and clamp at zero before taking the fractional power.
            shifted = F.relu(chosen + 10.0)

            # Sum of |x| ** 0.1 across the selected features, then mean over positions.
            l01 = (shifted.abs() ** 0.1).sum(dim=-1)
            results["target"] = l01.mean(dim=1)
            results["per_token_target"] = shifted
            return resid

        with self.model.hooks(fwd_hooks=[(self.hook_name, _score)]):
            logits = self.model(
                input_embeddings,
                start_at_layer=0,  # needed to skip embedding layer
                return_type="logits",
            )

        extras = {"per_token_target": results["per_token_target"]}
        return results["target"], logits, extras

    def one_hot_to_embed(self, one_hot: torch.Tensor) -> torch.Tensor:
        """Project one-hot encoded tokens through the embedding matrix ``W_E``."""
        return one_hot @ self.model.embed.W_E

    def int_ids_to_embed(self, int_ids: torch.Tensor) -> torch.Tensor:
        """Embed integer token ids with the model's embedding layer."""
        return self.model.embed(int_ids)


class MaxSAERunnerTLens(IRunner):
    """Runner whose target is a softmax-weighted ("soft max") SAE feature activation.

    For each selected feature, positions are weighted by a softmax over the
    sequence dimension, so the weighted sum approaches the maximum
    pre-activation. The per-feature results are summed.
    """

    def __init__(self, model: HookedTransformer, sae: SAE, features: List[int]):
        """
        Set up the runner for a TransformerLens model.

        Args:
            model: The TransformerLens model
            sae: The sparse autoencoder
            features: SAE feature indices to target

        Raises:
            AssertionError: If the model uses a post-embedding LayerNorm.
        """
        self.model = model
        self.sae = sae
        self.features = features

        self.hook_name = sae.cfg.hook_name

        assert not model.cfg.post_embedding_ln, (
            "post embedding ln currently not supported"
        )

    def run_with_embeddings(
        self, input_embeddings: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, dict]:
        """
        Run the model on embedded inputs and compute the soft-max feature score.

        Args:
            input_embeddings: Embedded input tensor

        Returns:
            Tuple of (target, logits, extras) where extras holds the raw
            per-position pre-activations under ``"per_token_target"``.
        """
        results = {}

        def _score(resid, hook):
            grabbed = []

            def _grab_pre(acts: torch.Tensor, hook):
                grabbed.append(acts)
                return acts

            self.sae.run_with_hooks(
                resid,
                fwd_hooks=[("hook_sae_acts_pre", _grab_pre)],
            )
            chosen = grabbed[0][:, :, self.features]

            # Softmax over sequence positions (dim 1) turns the values into
            # their own weights, giving a differentiable stand-in for max.
            weights = F.softmax(chosen, dim=1)
            soft_max_per_feature = (weights * chosen).sum(dim=1)
            results["target"] = soft_max_per_feature.sum(dim=-1)  # sum along features

            results["per_token_target"] = chosen
            return resid

        with self.model.hooks(fwd_hooks=[(self.hook_name, _score)]):
            logits = self.model(
                input_embeddings,
                start_at_layer=0,  # needed to skip embedding layer
                return_type="logits",
            )

        extras = {"per_token_target": results["per_token_target"]}
        return results["target"], logits, extras

    def one_hot_to_embed(self, one_hot: torch.Tensor) -> torch.Tensor:
        """Project one-hot encoded tokens through the embedding matrix ``W_E``."""
        return one_hot @ self.model.embed.W_E

    def int_ids_to_embed(self, int_ids: torch.Tensor) -> torch.Tensor:
        """Embed integer token ids with the model's embedding layer."""
        return self.model.embed(int_ids)


class SumSAERunnerTLens(IRunner):
    """Runner whose target is the summed, position-averaged pre-activation of several SAE features."""

    def __init__(self, model, sae, features):
        """
        Set up the runner for a TransformerLens model.

        Args:
            model: The TransformerLens model
            sae: The sparse autoencoder
            features: SAE feature indices whose pre-activations are summed

        Raises:
            AssertionError: If the model uses a post-embedding LayerNorm.
        """
        self.model = model
        self.sae = sae
        self.features = features
        self.hook_name = sae.cfg.hook_name

        assert not model.cfg.post_embedding_ln, (
            "post embedding ln currently not supported"
        )

    def run_with_embeddings(
        self, input_embeddings: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, dict]:
        """
        Run the model on embedded inputs and compute the summed feature score.

        Args:
            input_embeddings: Embedded input tensor

        Returns:
            Tuple of (target, logits, extras) where target averages each feature
            over positions and then sums across features, and extras holds the
            per-position values under ``"per_token_target"``.
        """
        results = {}

        def _score(resid, hook):
            grabbed = []

            def _grab_pre(acts: torch.Tensor, hook):
                grabbed.append(acts)
                return acts

            self.sae.run_with_hooks(
                resid,
                fwd_hooks=[("hook_sae_acts_pre", _grab_pre)],
            )
            chosen = grabbed[0][:, :, self.features]

            # Mean over positions first, then sum across the selected features.
            results["target"] = chosen.mean(dim=1).sum(-1)
            results["per_token_target"] = chosen
            return resid

        with self.model.hooks(fwd_hooks=[(self.hook_name, _score)]):
            logits = self.model(
                input_embeddings,
                start_at_layer=0,  # needed to skip embedding layer
                return_type="logits",
            )

        extras = {"per_token_target": results["per_token_target"]}
        return results["target"], logits, extras

    def one_hot_to_embed(self, one_hot: torch.Tensor) -> torch.Tensor:
        """Project one-hot encoded tokens through the embedding matrix ``W_E``."""
        return one_hot @ self.model.embed.W_E

    def int_ids_to_embed(self, int_ids: torch.Tensor) -> torch.Tensor:
        """Embed integer token ids with the model's embedding layer."""
        return self.model.embed(int_ids)



class MaxWithSumSAERunnerTLens(IRunner):
    """Runner combining per-feature and joint soft-max SAE scores.

    The target is the sum of two terms:
    - each selected feature's softmax-weighted (over positions) activation,
      summed across features;
    - the softmax-weighted (over positions) activation of the per-position sum
      of the selected features.
    """

    def __init__(self, model: HookedTransformer, sae: SAE, features: List[int]):
        """
        Set up the runner for a TransformerLens model.

        Args:
            model: The TransformerLens model
            sae: The sparse autoencoder
            features: SAE feature indices to target

        Raises:
            AssertionError: If the model uses a post-embedding LayerNorm.
        """
        self.model = model
        self.sae = sae
        self.features = features

        self.hook_name = sae.cfg.hook_name

        assert not model.cfg.post_embedding_ln, (
            "post embedding ln currently not supported"
        )

    def run_with_embeddings(
        self, input_embeddings: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, dict]:
        """
        Run the model on embedded inputs and compute the combined soft-max score.

        Args:
            input_embeddings: Embedded input tensor

        Returns:
            Tuple of (target, logits, extras) where extras holds the raw
            per-position pre-activations under ``"per_token_target"``.
        """
        results = {}

        def _score(resid, hook):
            grabbed = []

            def _grab_pre(acts: torch.Tensor, hook):
                grabbed.append(acts)
                return acts

            self.sae.run_with_hooks(
                resid,
                fwd_hooks=[("hook_sae_acts_pre", _grab_pre)],
            )
            chosen = grabbed[0][:, :, self.features]

            # Term 1: soft max over positions for each feature, summed over features.
            per_feature = (F.softmax(chosen, dim=1) * chosen).sum(dim=1).sum(dim=-1)

            # Term 2: soft max over positions of the features' per-position sum.
            # The summed tensor is 2-D, so dim=-1 here is the sequence axis.
            summed = chosen.sum(dim=-1)
            joint = (F.softmax(summed, dim=-1) * summed).sum(dim=1)

            results["target"] = per_feature + joint
            results["per_token_target"] = chosen
            return resid

        with self.model.hooks(fwd_hooks=[(self.hook_name, _score)]):
            logits = self.model(
                input_embeddings,
                start_at_layer=0,  # needed to skip embedding layer
                return_type="logits",
            )

        extras = {"per_token_target": results["per_token_target"]}
        return results["target"], logits, extras

    def one_hot_to_embed(self, one_hot: torch.Tensor) -> torch.Tensor:
        """Project one-hot encoded tokens through the embedding matrix ``W_E``."""
        return one_hot @ self.model.embed.W_E

    def int_ids_to_embed(self, int_ids: torch.Tensor) -> torch.Tensor:
        """Embed integer token ids with the model's embedding layer."""
        return self.model.embed(int_ids)




class DivergenceRunner(IRunner):
    """
    Wrap another runner and add a repulsion term to its target.

    The wrapped runner must return an ``"intermediate_residual"`` entry in its
    additional outputs. The residuals at ``fixed_positions`` are compared with
    each vector in ``repulsion_vectors``, and ``1 - similarity`` is added to
    (or multiplied into) the target.
    """

    def __init__(
        self,
        runner: IRunner,
        fixed_positions,
        divergence_weight=1.0,
        multiply_repulsions=False,
        normalize_residuals=True,
        use_fixed_positions=True,
        average_repulsions=False,
    ):
        """
        Initialize the wrapper with another runner.

        Args:
            runner: The underlying IRunner to wrap
            fixed_positions: Sequence positions (index/list/slice along dim 1 of
                the intermediate residual) to apply divergence to
            divergence_weight: Weight for the additive divergence term
                (default: 1.0). Not used when ``multiply_repulsions`` is True.
            multiply_repulsions: If True, multiply repulsion terms together and
                into the target instead of adding them
            normalize_residuals: Whether to L2-normalize residuals before the
                dot product with the (always normalized) repulsion vectors
                (default: True)
            use_fixed_positions: Stored as a flag; not read by this class —
                presumably consumed by the caller.
            average_repulsions: In additive mode, divide ``divergence_weight``
                by the number of repulsion vectors
        """
        self.runner = runner
        # Filled by the caller; nothing in this class appends to it. Each entry
        # is assumed to be a 1-D (d_model,) vector — TODO confirm.
        self.repulsion_vectors = []
        self.fixed_positions = fixed_positions
        self.divergence_weight = divergence_weight
        self.multiply_repulsions = multiply_repulsions
        self.normalize_residuals = normalize_residuals
        # A stray trailing comma used to store this as a 1-tuple (always truthy).
        self.use_fixed_positions = use_fixed_positions
        self.average_repulsions = average_repulsions

    def run_with_embeddings(
        self, input_embeddings: Tensor
    ) -> Tuple[
        Tensor,
        Tensor,
        Dict[str, Any],
    ]:
        """
        Runs the wrapped runner and adds a divergence penalty to push away from repulsion_vectors.

        Args:
            input_embeddings: Embedded input tensor

        Returns:
            Same return signature as the wrapped runner, with ``target`` adjusted.

        Raises:
            ValueError: If the wrapped runner does not provide
                ``"intermediate_residual"``.
        """
        target, logits, additional_outputs = self.runner.run_with_embeddings(input_embeddings)

        if "intermediate_residual" not in additional_outputs:
            raise ValueError("intermediate_residual not found in additional_outputs")

        # NOTE(review): if the wrapped runner detaches this tensor (as
        # TlensTokenDiffRunner does), the penalty shifts the target's value but
        # contributes no gradient with respect to input_embeddings.
        intermediate_residual = additional_outputs["intermediate_residual"]

        # Only use the representations at fixed positions for divergence
        position_residuals = intermediate_residual[:, self.fixed_positions, :]

        # Identity element for the chosen combination mode.
        if self.multiply_repulsions:
            divergence_penalty = torch.ones_like(target)
        else:
            divergence_penalty = torch.zeros_like(target)

        # numel() > 0 also guards an empty position selection; the former
        # len() check only tested batch size, and averaging over zero
        # positions produced a NaN penalty.
        if self.repulsion_vectors and position_residuals.numel() > 0:
            # Loop-invariant: normalize the residuals once.
            if self.normalize_residuals:
                used_residuals = F.normalize(position_residuals, p=2, dim=-1)
            else:
                used_residuals = position_residuals

            for repulsion_vector in self.repulsion_vectors:
                # Always normalize repulsion vectors
                normalized_repulsion = F.normalize(repulsion_vector, p=2, dim=-1)

                # Cosine similarity when residuals are normalized, otherwise a
                # dot product with the unit repulsion vector.
                similarity = torch.sum(
                    used_residuals * normalized_repulsion.unsqueeze(0).unsqueeze(1), dim=-1
                )

                # Average similarity across positions, then (1 - similarity)
                # so that dissimilar residuals score higher.
                repulsion_term = 1.0 - similarity.mean(dim=1)

                if self.multiply_repulsions:
                    divergence_penalty = divergence_penalty * repulsion_term
                else:
                    divergence_penalty = divergence_penalty + repulsion_term

        if self.multiply_repulsions:
            target = target * divergence_penalty
        else:
            if self.average_repulsions and len(self.repulsion_vectors) > 0:
                weight = self.divergence_weight * (1 / len(self.repulsion_vectors))
            else:
                weight = self.divergence_weight
            target = target + weight * divergence_penalty

        return target, logits, additional_outputs

    def one_hot_to_embed(self, one_hot: torch.Tensor) -> torch.Tensor:
        """Delegate to the wrapped runner's implementation."""
        return self.runner.one_hot_to_embed(one_hot)

    def int_ids_to_embed(self, int_ids: torch.Tensor) -> torch.Tensor:
        """Delegate to the wrapped runner's implementation."""
        return self.runner.int_ids_to_embed(int_ids)
