from typing import List, Optional, Tuple, Union, Callable, Any

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from configuration_gemma2_SAE import Gemma2_SAE_Config
except:
    from .configuration_gemma2_SAE import Gemma2_SAE_Config

from transformers import GemmaPreTrainedModel, GemmaModel, GemmaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.cache_utils import Cache
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class Gemma2_SAE(GemmaForCausalLM):
    config_class = Gemma2_SAE_Config
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: Gemma2_SAE_Config):
        super().__init__(config)
        self.model = GemmaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing

        if config.activation == "topk":
            if isinstance(config.activation_k, int):
                activation = TopK(torch.tensor(config.activation_k))
            else:
                activation = TopK(config.activation_k)
        elif config.activation == "topk-tanh":
            if isinstance(config.activation_k, int):
                activation = TopK(torch.tensor(config.activation_k), nn.Tanh())
            else:
                activation = TopK(config.activation_k, nn.Tanh())
        elif config.activation == "topk-sigmoid":
            if isinstance(config.activation_k, int):
                activation = TopK(torch.tensor(config.activation_k), nn.Sigmoid())
            else:
                activation = TopK(config.activation_k, nn.Sigmoid())
        elif config.activation == "jumprelu":
            activation = JumpReLu()
        elif config.activation == "relu":
            activation = "ReLU"
        elif config.activation == "identity":
            activation = "Identity"
        else:
            raise (
                NotImplementedError,
                f"Activation '{config.activation}' not implemented.",
            )

        self.SAE = Autoencoder(
            n_inputs=config.n_inputs,
            n_latents=config.n_latents,
            activation=activation,
            tied=False,
            normalize=True,
        )

        self.hook = HookedTransformer_with_SAE_suppresion(
            block=config.hook_block_num,
            sae=self.SAE,
            mod_features=config.mod_features,
            mod_threshold=config.mod_threshold,
            mod_replacement=config.mod_replacement,
            mod_scaling=config.mod_scaling,
        ).register_with(self.model, config.site)

        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, GemmaForCausalLM

        >>> model = GemmaForCausalLM.from_pretrained("meta-gemma/Gemma-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-gemma/Gemma-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss(reduction="none")
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
            loss = loss.view(logits.size(0), -1)
            mask = loss != 0
            loss = loss.sum(dim=-1) / mask.sum(dim=-1)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


def LN(
    x: torch.Tensor, eps: float = 1e-5
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    mu = x.mean(dim=-1, keepdim=True)
    x = x - mu
    std = x.std(dim=-1, keepdim=True)
    x = x / (std + eps)
    return x, mu, std


class Autoencoder(nn.Module):
    """Sparse autoencoder

    Implements:
        latents = activation(encoder(x - pre_bias) + latent_bias)
        recons = decoder(latents) + pre_bias
    """

    def __init__(
        self,
        n_latents: int,
        n_inputs: int,
        activation: Callable = nn.ReLU(),
        tied: bool = False,
        normalize: bool = False,
    ) -> None:
        """
        :param n_latents: dimension of the autoencoder latent
        :param n_inputs: dimensionality of the original data (e.g residual stream, number of MLP hidden units)
        :param activation: activation function
        :param tied: whether to tie the encoder and decoder weights
        """
        super().__init__()
        self.n_inputs = n_inputs
        self.n_latents = n_latents

        self.pre_bias = nn.Parameter(torch.zeros(n_inputs))
        self.encoder: nn.Module = nn.Linear(n_inputs, n_latents, bias=False)
        self.latent_bias = nn.Parameter(torch.zeros(n_latents))
        self.activation = activation

        if isinstance(activation, JumpReLu):
            self.threshold = nn.Parameter(torch.empty(n_latents))
            torch.nn.init.constant_(self.threshold, 0.001)
            self.forward = self.forward_jumprelu
        elif isinstance(activation, TopK):
            self.forward = self.forward_topk
        else:
            logger.warning(
                f"Using TopK forward function even if activation is not TopK, but is {activation}"
            )
            self.forward = self.forward_topk

        if tied:
            self.decoder = nn.Linear(n_latents, n_inputs, bias=False)
            self.decoder.weight.data = self.encoder.weight.data.T.clone()
        else:
            self.decoder = nn.Linear(n_latents, n_inputs, bias=False)
        self.normalize = normalize

    def encode_pre_act(
        self, x: torch.Tensor, latent_slice: slice = slice(None)
    ) -> torch.Tensor:
        """
        :param x: input data (shape: [batch, n_inputs])
        :param latent_slice: slice of latents to compute
            Example: latent_slice = slice(0, 10) to compute only the first 10 latents.
        :return: autoencoder latents before activation (shape: [batch, n_latents])
        """
        x = x - self.pre_bias
        latents_pre_act = F.linear(
            x, self.encoder.weight[latent_slice], self.latent_bias[latent_slice]
        )
        return latents_pre_act

    def preprocess(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:
        if not self.normalize:
            return x, dict()
        x, mu, std = LN(x)
        return x, dict(mu=mu, std=std)

    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:
        """
        :param x: input data (shape: [batch, n_inputs])
        :return: autoencoder latents (shape: [batch, n_latents])
        """
        x, info = self.preprocess(x)
        return self.activation(self.encode_pre_act(x)), info

    def decode(
        self, latents: torch.Tensor, info: dict[str, Any] | None = None
    ) -> torch.Tensor:
        """
        :param latents: autoencoder latents (shape: [batch, n_latents])
        :return: reconstructed data (shape: [batch, n_inputs])
        """
        ret = self.decoder(latents) + self.pre_bias
        if self.normalize:
            assert info is not None
            ret = ret * info["std"] + info["mu"]
        return ret

    def forward_topk(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        :param x: input data (shape: [batch, n_inputs])
        :return:  autoencoder latents pre activation (shape: [batch, n_latents])
                  autoencoder latents (shape: [batch, n_latents])
                  reconstructed data (shape: [batch, n_inputs])
        """
        x, info = self.preprocess(x)
        latents_pre_act = self.encode_pre_act(x)
        latents = self.activation(latents_pre_act)
        recons = self.decode(latents, info)

        return latents_pre_act, latents, recons

    def forward_jumprelu(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        :param x: input data (shape: [batch, n_inputs])
        :return:  autoencoder latents pre activation (shape: [batch, n_latents])
                  autoencoder latents (shape: [batch, n_latents])
                  reconstructed data (shape: [batch, n_inputs])
        """
        x, info = self.preprocess(x)
        latents_pre_act = self.encode_pre_act(x)
        latents = self.activation(F.relu(latents_pre_act), torch.exp(self.threshold))
        recons = self.decode(latents, info)

        return latents_pre_act, latents, recons


class TiedTranspose(nn.Module):
    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert self.linear.bias is None
        return F.linear(x, self.linear.weight.t(), None)

    @property
    def weight(self) -> torch.Tensor:
        return self.linear.weight.t()

    @property
    def bias(self) -> torch.Tensor:
        return self.linear.bias


class TopK(nn.Module):
    def __init__(self, k: int, postact_fn: Callable = nn.ReLU()) -> None:
        super().__init__()
        self.k = k
        self.postact_fn = postact_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        topk = torch.topk(x, k=self.k, dim=-1)
        values = self.postact_fn(topk.values)
        # make all other values 0
        result = torch.zeros_like(x)
        result.scatter_(-1, topk.indices, values)
        return result


class JumpReLu(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, threshold):
        return JumpReLUFunction.apply(input, threshold)


class HeavyStep(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, threshold):
        return HeavyStepFunction.apply(input, threshold)


def rectangle(x):
    return (x > -0.5) & (x < 0.5)


class JumpReLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(input, threshold):
        output = input * (input > threshold)
        return output

    @staticmethod
    def setup_context(ctx, inputs, output):
        input, threshold = inputs
        ctx.save_for_backward(input, threshold)

    @staticmethod
    def backward(ctx, grad_output):
        bandwidth = 0.001
        # bandwidth = 0.0001
        input, threshold = ctx.saved_tensors
        grad_input = grad_threshold = None

        grad_input = input > threshold
        grad_threshold = (
            -(threshold / bandwidth)
            * rectangle((input - threshold) / bandwidth)
            * grad_output
        )

        return grad_input, grad_threshold


class HeavyStepFunction(torch.autograd.Function):
    @staticmethod
    def forward(input, threshold):
        output = input * threshold
        return output

    @staticmethod
    def setup_context(ctx, inputs, output):
        input, threshold = inputs
        ctx.save_for_backward(input, threshold)

    @staticmethod
    def backward(ctx, grad_output):
        bandwidth = 0.001
        # bandwidth = 0.0001
        input, threshold = ctx.saved_tensors
        grad_input = grad_threshold = None

        grad_input = torch.zeros_like(input)
        grad_threshold = (
            -(1.0 / bandwidth)
            * rectangle((input - threshold) / bandwidth)
            * grad_output
        )

        return grad_input, grad_threshold


ACTIVATIONS_CLASSES = {
    "ReLU": nn.ReLU,
    "Identity": nn.Identity,
    "TopK": TopK,
    "JumpReLU": JumpReLu,
}




class HookedTransformer_with_SAE_suppresion:
    """Auxilliary class used to extract mlp activations from transformer models."""

    def __init__(
        self,
        block: int,
        sae: Autoencoder,
        mod_features: list = None,
        mod_threshold: list = None,
        mod_replacement: list = None,
        mod_scaling: list = None,
        mod_balance: bool = False,
        multi_feature: bool = False,
    ) -> None:
        self.block = block
        self.sae = sae

        self.remove_handle = (
            None  # Can be used to remove this hook from the model again
        )

        self._features = None
        self.mod_features = mod_features
        self.mod_threshold = mod_threshold
        self.mod_replacement = mod_replacement
        self.mod_scaling = mod_scaling
        self.mod_balance = mod_balance
        self.mod_vector = None
        self.mod_vec_factor = 1.0

        if isinstance(self.sae.activation, JumpReLu):
            logger.info("Setting __call__ function for JumpReLU.")
            setattr(self, "call", self.__call__jumprelu)
        elif isinstance(self.sae.activation, TopK):
            logger.info("Setting __call__ function for TopK.")
            setattr(self, "call", self.__call__topk)
        else:
            logger.warning(
                f"Using TopK forward function even if activation is not TopK, but is {self.sae.activation}"
            )
            setattr(self, "call", self.__call__topk)

    def register_with(self, model, site="mlp"):
        self.site = site
        # Decision on where to extract activations from
        if site == "mlp":  # output of the FF module of block
            self.remove_handle = model.layers[self.block].mlp.register_forward_hook(
                self
            )
        elif (
            site == "block"
        ):  # output of the residual connection AFTER it is added to the FF output
            self.remove_handle = model.layers[self.block].register_forward_hook(self)
        elif site == "attention":
            raise NotImplementedError
        else:
            raise NotImplementedError


        return self

    def pop(self) -> torch.Tensor:
        """Remove and return extracted feature from this hook.

        We only allow access to the features this way to not have any lingering references to them.
        """
        assert self._features is not None, "Feature extractor was not called yet!"
        if isinstance(self._features, tuple):
            features = self._features[0]
        else:
            features = self._features
        self._features = None
        return features

    def __call__topk(self, module, inp, outp) -> torch.Tensor:
        self._features = outp
        if isinstance(self._features, tuple):
            features = self._features[0]
        else:
            features = self._features

        if self.mod_features is None:
            recons = features
        else:
            x, info = self.sae.preprocess(features)
            latents_pre_act = self.sae.encode_pre_act(x)
            latents = self.sae.activation(latents_pre_act)
           

            recons = self.sae.decode(latents, info)

        if isinstance(self._features, tuple):
            outp = list(outp)
            outp[0] = recons
            return tuple(outp)
        else:
            return recons

    def __call__jumprelu(self, module, inp, outp) -> torch.Tensor:
        self._features = outp
        if self.mod_features is None:
            recons = outp
        else:
            x, info = self.sae.preprocess(outp)
            latents_pre_act = self.sae.encode_pre_act(x)
            latents = self.sae.activation(
                F.relu(latents_pre_act), torch.exp(self.sae.threshold)
            )
            recons = self.sae.decode(latents, info)

        return recons

    def __call__(self, module, inp, outp) -> torch.Tensor:
        return self.call(module, inp, outp)
