# Inspired by https://github.com/zhangmengling/LLMScan/tree/main
import gc
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Union

import torch
from jaxtyping import Float
from tqdm import tqdm, trange

import transformers
from code_demeanor.logger import logger
from code_demeanor.reading.schemas import BypassedOutput, ScanningType
from code_demeanor.utils import (
    _fro_norm_per_batch,
    _stat_features,
    load_tensors_jsonl,
    save_tensor_jsonl,
)
from code_demeanor.utils.flops import (
    _human,
    estimate_scan_flops,
    spec_layer_wise,
    spec_microsaccades,
    spec_token_wise_per_string,
)

# Number of summary statistics computed per layer: mean, std, range,
# skewness and kurtosis.
N_STATS = len(("mean", "std", "range", "skewness", "kurtosis"))


class Scanner(ABC):
    """Abstract base class for LLM scanners.

    A scanner wraps a Hugging Face causal LM and its tokenizer. Subclasses
    implement :meth:`scan_function`, which turns a list of code snippets into
    ``(causal_effects, stats)``; :meth:`scan` drives it one snippet at a time and
    concatenates the non-``None`` results along dim 0.

    The manual layer-by-layer forward supports GPT-2 (``GPT2Model`` /
    ``GPT2LMHeadModel``), ``LlamaForCausalLM`` and ``Qwen2ForCausalLM``.
    """

    def __init__(
        self,
        model,
        tokenizer,
        scanning_type: ScanningType,
        hidden_layers: list = None,
        heads: list = None,
        seq_length: int = None,
        device: str = "cpu",
        verbose: bool = False,
        replacement_token: str = "-",
        batch_size: int = 1,
        use_cache: bool = True,
        combine_causal_effects: bool = True,
        shadow_run: bool = False,
        assume_full_sequence_lm_head: bool = False,
    ):
        """Store the model/tokenizer and scan configuration.

        Args:
            model: Hugging Face model to scan.
            tokenizer: matching tokenizer; ``replacement_token`` is converted to an
                id with it.
            scanning_type: which scan variant the caller requested.
            hidden_layers: layer indices of interest (``None`` -> empty list).
            heads: attention-head indices of interest (``None`` -> empty list).
            seq_length: if set, inputs are padded/truncated to this many tokens.
            device: device the tokenized inputs are moved to.
            verbose: log intermediate shapes.
            replacement_token: token used when replacing input tokens.
            batch_size, use_cache, combine_causal_effects: stored for subclasses.
            shadow_run, assume_full_sequence_lm_head: only used for FLOPs estimation.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.verbose = verbose
        self.scanning_type = scanning_type
        self.hidden_layers = hidden_layers if hidden_layers is not None else []
        self.heads = heads if heads is not None else []
        self.seq_length = seq_length
        self.replacement_token = replacement_token

        self.replacement_token_id = self.tokenizer.convert_tokens_to_ids(
            self.replacement_token
        )
        self.combine_causal_effects = combine_causal_effects
        self.batch_size = batch_size
        self.use_cache = use_cache
        self._cache = {}

        # Purely utils for flops estimation
        self.shadow_run = shadow_run
        self.assume_full_sequence_lm_head = assume_full_sequence_lm_head

        logger.info(
            f"Initialized Scanner with scanning_type={self.scanning_type}, "
            f"hidden_layers={self.hidden_layers}, seq_length={self.seq_length}, "
            f"device={self.device}, replacement_token_id={self.replacement_token_id}, "
            f"heads={self.heads}, model_name={self.model.name_or_path}, "
            f"batch_size={self.batch_size}, combine_causal_effects={self.combine_causal_effects}"
        )

    def _get_decoder_layers(self):
        """Return ``(decoder_layers, rotary_embedding)`` for the wrapped model.

        ``rotary_embedding`` is ``None`` for GPT-2, which has no RoPE module.

        Raises:
            ValueError: if the model type is not supported.
        """
        if isinstance(self.model, transformers.GPT2Model):
            return self.model.h, None
        if isinstance(self.model, transformers.GPT2LMHeadModel):
            return self.model.transformer.h, None
        if isinstance(
            self.model,
            (
                transformers.models.llama.modeling_llama.LlamaForCausalLM,
                transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM,
            ),
        ):
            return self.model.model.layers, self.model.model.rotary_emb
        raise ValueError("Unsupported model type for layer-wise scanning.")

    def _embed_for_manual_forward(
        self, input_ids: Float[torch.Tensor, "batch seq_len"]
    ) -> Float[torch.Tensor, "batch seq_len hidden_size"]:
        """Compute the input to decoder layer 0.

        ``get_input_embeddings`` returns only the token table. GPT-2 adds its
        learned absolute position embeddings (``wpe``) inside its own forward, so
        they are added here explicitly. Llama/Qwen2 inject positions through RoPE
        inside attention and need token embeddings only.
        """
        embeddings = self.model.get_input_embeddings()(input_ids)
        if isinstance(
            self.model, (transformers.GPT2Model, transformers.GPT2LMHeadModel)
        ):
            # GPT2LMHeadModel wraps the backbone in `.transformer`; GPT2Model is it.
            backbone = getattr(self.model, "transformer", self.model)
            positions = torch.arange(input_ids.size(1), device=input_ids.device)
            embeddings = embeddings + backbone.wpe(positions)
        return embeddings

    @staticmethod
    def _causal_additive_mask(
        attention_mask, batch_size: int, seq_len: int, dtype, device
    ) -> Float[torch.Tensor, "batch 1 seq_len seq_len"]:
        """Build an additive ``[batch, 1, seq_len, seq_len]`` attention mask.

        Allowed (query, key) pairs get 0. Future keys and padded keys (where
        ``attention_mask`` is 0) get ``-1e4``. Decoder layers called directly do
        not add a causal mask themselves when one is passed in (Llama/Qwen2 eager
        attention, SDPA with an explicit mask), so causality is encoded here.
        A finite value keeps fully padded rows free of NaNs.
        """
        allowed = torch.tril(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
        )
        allowed = allowed[None, None, :, :].expand(batch_size, 1, seq_len, seq_len)
        if attention_mask is not None:
            keep_keys = attention_mask.to(device=device).bool()[:, None, None, :]
            allowed = allowed & keep_keys
        return (~allowed).to(dtype) * -1e4

    @staticmethod
    def _split_layer_output(layer_out):
        """Return ``(hidden_states, attn_weights)`` from a decoder layer's output.

        Depending on the transformers version a layer returns either a bare
        tensor or a tuple ``(hidden_states, attn_weights, ...)``. When no
        attention entry exists, a ``torch.tensor([0])`` placeholder is returned.
        """
        if isinstance(layer_out, torch.Tensor):
            return layer_out, torch.tensor([0])
        if len(layer_out) >= 2:
            return layer_out[0], layer_out[1]
        return layer_out[0], torch.tensor([0])

    @torch.no_grad()
    def scan_with_layer_patch_base(
        self,
        input_ids: Float[torch.Tensor, "batch seq_len"],
        attention_mask: Float[torch.Tensor, "batch seq_len"],
        layer_idx_to_bypass: int,
    ) -> BypassedOutput:
        """Re-run the decoder stack with one layer skipped and capture attentions.

        Embeddings are computed as in the model's own forward. Every decoder layer
        except ``layer_idx_to_bypass`` is applied in order with a causal + padding
        mask. The final norm and LM head then give next-token probabilities at
        the last position.

        Side effect: sets ``self.model.config.output_attentions = True``.

        Args:
            input_ids: token ids ``[batch, seq_len]``.
            attention_mask: ``[batch, seq_len]``, 1 for tokens to attend to and 0
                for padding, or ``None``.
            layer_idx_to_bypass: decoder layer to skip. An out-of-range index
                skips nothing.

        Returns:
            BypassedOutput with CPU tensors. ``all_hidden_states`` holds the
            embeddings followed by each executed layer's output. ``all_attentions``
            has one entry per layer and is ``None`` at the bypassed index.

        Raises:
            ValueError: for unsupported model types or models without an LM head.
        """
        # Ensure the model returns attentions
        self.model.config.output_attentions = True

        all_layers, rotary = self._get_decoder_layers()
        hidden_states = self._embed_for_manual_forward(input_ids)
        batch_size, seq_len = hidden_states.shape[:2]
        layer_mask = self._causal_additive_mask(
            attention_mask, batch_size, seq_len, hidden_states.dtype, hidden_states.device
        )

        # RoPE cos/sin depend only on positions, dtype and device, so they are
        # identical for every layer: compute them once.
        layer_kwargs = {}
        if rotary is not None:
            position_ids = (
                torch.arange(seq_len, device=hidden_states.device, dtype=torch.int64)
                .unsqueeze(0)
                .expand(batch_size, -1)
            )
            layer_kwargs["position_embeddings"] = rotary.forward(
                x=hidden_states, position_ids=position_ids
            )

        all_hidden_states = [hidden_states]
        all_attentions = []  # one entry per layer, None for the bypassed layer

        for layer_idx, layer in enumerate(all_layers):
            if layer_idx == layer_idx_to_bypass:
                if self.verbose:
                    logger.info(f"Bypassing layer {layer_idx}")
                # keep alignment with the layer index
                all_attentions.append(None)
                continue

            layer_out = layer(
                hidden_states,
                attention_mask=layer_mask,
                use_cache=False,
                output_attentions=True,
                **layer_kwargs,
            )
            hidden_states, attn_weights = self._split_layer_output(layer_out)

            all_hidden_states.append(hidden_states)
            all_attentions.append(attn_weights)

            if self.verbose:
                logger.info(
                    f"Layer {layer_idx} output {hidden_states.shape}, "
                    f"attn {None if attn_weights is None else tuple(attn_weights.shape)}"
                )

        if self.verbose:
            logger.info(
                f"Final hidden state shape after bypassing layer {layer_idx_to_bypass}: {hidden_states.shape}"
            )

        logits = self._last_linear_forward(hidden_states)  # [batch, seq_len, vocab]
        probs = torch.softmax(logits[:, -1, :], dim=-1)

        # Pass all relevant outputs to cpu!
        return BypassedOutput(
            input_ids=input_ids.cpu(),
            patched_probs=probs.cpu(),
            all_hidden_states=[h.cpu() for h in all_hidden_states],
            all_attentions=[a.cpu() if a is not None else None for a in all_attentions],
        )

    def _get_logits(
        self, outputs: Union["BaseModelOutputWithPastAndCrossAttentions"]
    ) -> Tuple[
        Float[torch.Tensor, "batch vocab_size"],
        Float[torch.Tensor, "batch vocab_size"],
    ]:
        """Return ``(next_token_logits, next_token_probs)`` at the last position.

        NOTE(review): for a bare ``BaseModelOutputWithPastAndCrossAttentions``
        (e.g. ``GPT2Model``) ``last_hidden_state`` is a hidden state, not logits,
        so the softmax is over the hidden dimension. Confirm this path is wanted.
        """
        from transformers.modeling_outputs import (
            BaseModelOutputWithPastAndCrossAttentions,
        )

        if isinstance(outputs, BaseModelOutputWithPastAndCrossAttentions):
            logits = outputs.last_hidden_state  # shape: (batch, seq_len, hidden_size)
        else:
            logits = outputs.logits

        next_token_logits = logits[:, -1, :]
        next_token_probs = torch.softmax(next_token_logits, dim=-1)
        if self.verbose:
            logger.info(f"Next token probabilities shape: {next_token_probs.shape}")
        return next_token_logits, next_token_probs

    @torch.no_grad()
    def _llm_forward(self, code: Union[str, List[str]]) -> Tuple[
        Tuple[Float[torch.Tensor, "batch seq_len hidden_size"], ...],
        Float[torch.Tensor, "batch vocab_size"],
        Dict[str, torch.Tensor],
        Tuple[Float[torch.Tensor, "batch num_heads seq_len seq_len"], ...],
    ]:
        """Tokenize ``code`` and run the model.

        Returns ``(hidden_states, next_token_probs, inputs, attentions)``.
        """
        inputs = self._tokenize_code(code)
        return self._llm_forward_base(inputs)

    @torch.no_grad()
    def _llm_forward_tokens(self, tokens: Dict[str, torch.Tensor]) -> Tuple[
        Tuple[Float[torch.Tensor, "batch seq_len hidden_size"], ...],
        Float[torch.Tensor, "batch vocab_size"],
        Dict[str, torch.Tensor],
        Tuple[Float[torch.Tensor, "batch num_heads seq_len seq_len"], ...],
    ]:
        """Run the model on already tokenized inputs (a mapping of model kwargs)."""
        return self._llm_forward_base(tokens)

    @torch.no_grad()
    def _llm_forward_base(self, inputs: Dict[str, torch.Tensor]) -> Tuple[
        Tuple[Float[torch.Tensor, "batch seq_len hidden_size"], ...],
        Float[torch.Tensor, "batch vocab_size"],
        Dict[str, torch.Tensor],
        Tuple[Float[torch.Tensor, "batch num_heads seq_len seq_len"], ...],
    ]:
        """Run the model with hidden states and attentions enabled.

        Returns:
            ``(hidden_states, next_token_probs, inputs, attention_scores)`` where
            ``hidden_states``/``attention_scores`` are the model's output tuples
            and ``next_token_probs`` is the softmax at the last position.
        """
        outputs = self.model(
            **inputs,
            output_hidden_states=True,
            return_dict=True,
            output_attentions=True,
        )
        attention_scores = (
            outputs.attentions if hasattr(outputs, "attentions") else None
        )
        assert (
            attention_scores is not None
        ), "Model does not return attention scores. Ensure output_attentions=True in model config."
        hidden_states = outputs.hidden_states
        _, next_token_probs = self._get_logits(outputs)

        return hidden_states, next_token_probs, inputs, attention_scores

    def _last_linear_forward(
        self, last_hidden_state: Float[torch.Tensor, "batch seq_len hidden_size"]
    ) -> Float[torch.Tensor, "batch seq_len vocab_size"]:
        """Apply the model's final norm and LM head to last-layer hidden states.

        The final norm (``model.norm`` for Llama/Qwen2, ``transformer.ln_f`` for
        GPT-2) is applied here, so the input should be the raw output of the last
        decoder layer. The normed states are cast to the LM head weight's dtype
        and device.

        Raises:
            ValueError: if the model has no ``lm_head`` or its final norm is unknown.
        """
        if not hasattr(self.model, "lm_head"):
            raise ValueError(
                "Model does not have a lm_head for next-token probabilities."
            )

        if isinstance(
            self.model,
            (
                transformers.models.llama.modeling_llama.LlamaForCausalLM,
                transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM,
            ),
        ):
            # Llama/Qwen2: apply the final RMSNorm before the LM head
            assert hasattr(self.model, "model") and hasattr(
                self.model.model, "norm"
            ), f"{type(self.model).__name__} does not have a 'model.norm' attribute."
            layer_normed = self.model.model.norm(last_hidden_state)
        elif isinstance(self.model, transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel):
            assert hasattr(self.model, "transformer") and hasattr(
                self.model.transformer, "ln_f"
            ), "GPT2LMHeadModel does not have a 'transformer.ln_f' attribute."
            layer_normed = self.model.transformer.ln_f(last_hidden_state)
        else:
            raise ValueError(
                f"No known final layer norm for model type "
                f"{type(self.model).__name__}. Check the model architecture."
            )

        w = self.model.lm_head.weight
        layer_normed = layer_normed.to(dtype=w.dtype, device=w.device)
        logits = self.model.lm_head(layer_normed)  # shape: [batch, seq_len, vocab_size]

        if self.verbose:
            logger.info(f"Logits shape: {logits.shape}")
        return logits

    def _get_next_token_probabilities(
        self, hidden_states: Tuple[torch.Tensor]
    ) -> torch.Tensor:
        """Return softmax next-token probabilities from ``hidden_states[-1]``.

        ``hidden_states[-1]`` goes through :meth:`_last_linear_forward`, which
        applies the final norm, and the softmax is taken at the last position.

        NOTE(review): Hugging Face GPT-2/Llama models return
        ``outputs.hidden_states[-1]`` *after* their final norm, so passing that
        tuple here normalises twice. Confirm callers pass raw layer outputs.
        """
        final_hidden = hidden_states[-1]  # shape: [batch, seq_len, hidden_size]
        if self.verbose:
            logger.info(f"Final hidden state shape: {final_hidden.shape}")

        logits = self._last_linear_forward(final_hidden)
        return torch.softmax(logits[:, -1, :], dim=-1)  # shape: [batch, vocab_size]

    def _get_token_ids(self, code: str) -> Float[torch.Tensor, "batch seq_len"]:
        """Return the ``input_ids`` tensor for ``code``."""
        inputs = self._tokenize_code(code)
        return inputs["input_ids"]

    def _tokenize_code(self, code: Union[str, List[str]]) -> Dict[str, torch.Tensor]:
        """Tokenize ``code`` into PyTorch tensors on ``self.device``.

        Side effect: when the tokenizer has no pad token, its ``pad_token_id`` is
        set to ``eos_token_id``. With ``seq_length`` set, inputs are padded to and
        truncated at ``seq_length`` tokens.

        NOTE(review): padding goes on the tokenizer's ``padding_side``. With
        right padding, callers that read position ``-1`` see a pad position.
        Confirm the padding side.
        """
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        if self.seq_length is not None:
            tokenizer_kwargs = {
                "padding": "max_length",
                "truncation": True,
                "max_length": self.seq_length,
            }
        else:
            tokenizer_kwargs = {}

        return self.tokenizer(
            code,
            return_tensors="pt",
            **tokenizer_kwargs,
        ).to(self.device)

    def _get_tokens(
        self, code: str
    ) -> Tuple[Dict[str, torch.Tensor], Float[torch.Tensor, "seq_len"]]:
        """Return ``(inputs, input_ids.squeeze(0))`` for ``code``."""
        inputs = self._tokenize_code(code)
        return inputs, inputs["input_ids"].squeeze(0)

    def scan(self, code: Union[str, List[str]]) -> Tuple[
        Union[Float[torch.Tensor, "n_examples n_layers n_stats"], None],
        Union[Float[torch.Tensor, "n_examples n_effects"], None],
    ]:
        """Scan each snippet in ``code`` and concatenate the results.

        Each snippet is passed to :meth:`scan_function` as a one-element list.
        Non-``None`` results are concatenated along dim 0.

        Returns:
            ``(all_stats, all_causal_effects)``. Each is ``None`` when no snippet
            produced that output.
        """
        if isinstance(code, str):
            code = [code]

        stats_chunks: List[torch.Tensor] = []
        effect_chunks: List[torch.Tensor] = []

        for code_entry in tqdm(code):
            assert isinstance(
                code_entry, str
            ), f"Each code entry must be a string, got {type(code_entry)}"

            causal_effects, stats = self.scan_function([code_entry])
            if stats is not None:
                stats_chunks.append(stats)
            if causal_effects is not None:
                effect_chunks.append(causal_effects)

        # Concatenate once at the end (repeated cat inside the loop is quadratic).
        all_stats = torch.cat(stats_chunks, dim=0) if stats_chunks else None
        all_causal_effects = torch.cat(effect_chunks, dim=0) if effect_chunks else None
        return all_stats, all_causal_effects

    @abstractmethod
    def scan_function(self, code: Union[str, List[str]]) -> Tuple[
        Union[Float[torch.Tensor, "n_examples n_effects"], None],
        Union[Float[torch.Tensor, "n_examples n_layers n_stats"], None],
    ]:
        """Scan ``code`` and return ``(causal_effects, stats)``; either may be ``None``."""
        pass


class MicrosaccadesScanner(Scanner):
    """Scanner that perturbs the input embeddings with a positional signal.

    A baseline forward runs the unmodified model. A second, manual forward runs
    every decoder layer on embeddings to which a positional "microsaccade"
    intervention has been added. The causal effects are per-example norms
    (computed by ``_fro_norm_per_batch``) of two differences:

    - the change in next-token probabilities;
    - the change in every layer/head attention map.
    """

    def __init__(
        self,
        model,
        tokenizer,
        scanning_type: ScanningType,
        hidden_layers: list = None,
        heads: list = None,
        seq_length: int = None,
        device: str = "cpu",
        verbose: bool = False,
        replacement_token: str = "-",
        batch_size: int = 1,
        use_cache: bool = True,
        combine_causal_effects: bool = True,
        shadow_run: bool = False,
        assume_full_sequence_lm_head: bool = False,
        random_positional_encoding: bool = False,
        gaussian_positional_encoding: bool = False,
    ):
        """Create the scanner.

        Args:
            random_positional_encoding: use small random noise as the intervention.
            gaussian_positional_encoding: use a Gaussian bump over positions as the
                intervention (ignored when ``random_positional_encoding`` is set).
            All other arguments are forwarded unchanged to ``Scanner``.
        """
        super().__init__(
            model,
            tokenizer,
            scanning_type,
            hidden_layers,
            heads,
            seq_length,
            device,
            verbose,
            replacement_token,
            batch_size,
            use_cache,
            combine_causal_effects,
            shadow_run,
            assume_full_sequence_lm_head,
        )
        self.random_positional_encoding = random_positional_encoding
        self.gaussian_positional_encoding = gaussian_positional_encoding

    def _get_random_positional_encoding_intervention(
        self, positional_encoding: Float[torch.Tensor, "batch seq_len hidden_size"]
    ) -> Float[torch.Tensor, "batch seq_len hidden_size"]:
        """Add standard-normal noise scaled by 0.01 to every entry.

        The noise is drawn in the default dtype and then cast to the input's
        dtype. Half-precision embeddings therefore stay half precision; float32
        would break the bf16/fp16 decoder layers.
        """
        noise = (
            torch.randn(positional_encoding.shape, device=positional_encoding.device)
            * 0.01
        )  # small noise
        return positional_encoding + noise.to(positional_encoding.dtype)

    def _get_gaussian_positional_encoding_intervention(
        self, positional_encoding: Float[torch.Tensor, "batch seq_len hidden_size"]
    ) -> Float[torch.Tensor, "batch seq_len hidden_size"]:
        """Add a Gaussian bump over positions, identical across the hidden dim.

        The bump is centred at ``seq_len / 2`` with std ``seq_len / 6``, so about
        99.7% of its mass lies inside the sequence. It is cast to the input's
        dtype.
        """
        _, seq_len, _ = positional_encoding.size()
        positions = torch.arange(seq_len, device=positional_encoding.device).float()

        mean = seq_len / 2
        std = seq_len / 6
        gauss = torch.exp(-0.5 * ((positions - mean) / std) ** 2)  # [seq_len]

        # Broadcast [1, seq_len, 1] over batch and hidden dims.
        bump = gauss[None, :, None].to(positional_encoding.dtype)
        return positional_encoding + bump

    def _get_positional_encoding_intervention(
        self, positional_encoding: Float[torch.Tensor, "batch seq_len hidden_size"]
    ) -> Float[torch.Tensor, "batch seq_len hidden_size"]:
        """Add a standard sinusoidal positional encoding (base 10000).

        Even hidden indices get ``sin``, odd ones get ``cos``. The encoding is
        cast to the input's dtype.
        """
        _, seq_len, hidden_size = positional_encoding.size()
        device = positional_encoding.device

        position_ids = torch.arange(seq_len, device=device).unsqueeze(1)  # [seq_len, 1]
        div_term = torch.exp(
            torch.arange(0, hidden_size, 2, device=device)
            * -(torch.log(torch.tensor(10000.0)) / hidden_size)
        )  # [ceil(hidden_size / 2)]

        pe = torch.zeros(seq_len, hidden_size, device=device)
        pe[:, 0::2] = torch.sin(position_ids * div_term)
        # With an odd hidden_size there is one fewer odd column than frequencies.
        pe[:, 1::2] = torch.cos(position_ids * div_term[: hidden_size // 2])

        return positional_encoding + pe.to(positional_encoding.dtype).unsqueeze(0)

    def _embed_with_intervention(
        self, input_ids: Float[torch.Tensor, "batch seq_len"]
    ) -> Float[torch.Tensor, "batch seq_len hidden_size"]:
        """Embed ``input_ids`` and add the configured positional intervention.

        For GPT-2, the learned position embeddings (``wpe``) are added first, as
        in the model's own forward. RoPE models use token embeddings only. The
        intervention is chosen in this order: random if
        ``random_positional_encoding``, else Gaussian if
        ``gaussian_positional_encoding``, else sinusoidal.
        """
        embeddings = self.model.get_input_embeddings()(input_ids)
        if isinstance(
            self.model, (transformers.GPT2Model, transformers.GPT2LMHeadModel)
        ):
            # GPT2LMHeadModel wraps the backbone in `.transformer`; GPT2Model is it.
            backbone = getattr(self.model, "transformer", self.model)
            positions = torch.arange(input_ids.size(1), device=input_ids.device)
            embeddings = embeddings + backbone.wpe(positions)

        if self.random_positional_encoding:
            return self._get_random_positional_encoding_intervention(embeddings)
        if self.gaussian_positional_encoding:
            return self._get_gaussian_positional_encoding_intervention(embeddings)
        return self._get_positional_encoding_intervention(embeddings)

    def _run_intervened_decoder(
        self,
        input_ids: Float[torch.Tensor, "batch seq_len"],
        hidden_states: Float[torch.Tensor, "batch seq_len hidden_size"],
        attention_mask,
    ) -> BypassedOutput:
        """Run every decoder layer on ``hidden_states`` and collect CPU outputs.

        A causal + padding additive mask is passed to each layer. Layers called
        directly do not add a causal mask themselves when a mask is supplied.

        Raises:
            ValueError: for unsupported model types or models without an LM head.
        """
        if isinstance(self.model, transformers.GPT2Model):
            all_layers, rotary = self.model.h, None
        elif isinstance(self.model, transformers.GPT2LMHeadModel):
            all_layers, rotary = self.model.transformer.h, None
        elif isinstance(
            self.model,
            (
                transformers.models.llama.modeling_llama.LlamaForCausalLM,
                transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM,
            ),
        ):
            all_layers, rotary = self.model.model.layers, self.model.model.rotary_emb
        else:
            raise ValueError("Unsupported model type for layer-wise scanning.")

        batch_size, seq_len = hidden_states.shape[:2]
        device = hidden_states.device

        # 0 where a query may attend, -1e4 for future or padded keys.
        allowed = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
        allowed = allowed[None, None, :, :].expand(batch_size, 1, seq_len, seq_len)
        if attention_mask is not None:
            allowed = allowed & attention_mask.to(device=device).bool()[:, None, None, :]
        layer_mask = (~allowed).to(hidden_states.dtype) * -1e4

        # RoPE cos/sin are the same for every layer: compute them once.
        layer_kwargs = {}
        if rotary is not None:
            position_ids = (
                torch.arange(seq_len, device=device, dtype=torch.int64)
                .unsqueeze(0)
                .expand(batch_size, -1)
            )
            layer_kwargs["position_embeddings"] = rotary.forward(
                x=hidden_states, position_ids=position_ids
            )

        all_hidden_states = [hidden_states]
        all_attentions = []  # per-layer attention maps: [batch, heads, seq, seq]

        for layer_idx, layer in enumerate(all_layers):
            layer_out = layer(
                hidden_states,
                attention_mask=layer_mask,
                use_cache=False,
                output_attentions=True,
                **layer_kwargs,
            )
            # Newer transformers return a bare tensor; older ones return a tuple.
            if isinstance(layer_out, torch.Tensor):
                hidden_states, attn_weights = layer_out, torch.tensor([0])
            elif len(layer_out) >= 2:
                hidden_states, attn_weights = layer_out[0], layer_out[1]
            else:
                hidden_states, attn_weights = layer_out[0], torch.tensor([0])

            all_hidden_states.append(hidden_states)
            all_attentions.append(attn_weights)

            if self.verbose:
                logger.info(
                    f"Layer {layer_idx} output {hidden_states.shape}, "
                    f"attn {None if attn_weights is None else tuple(attn_weights.shape)}"
                )

        if self.verbose:
            logger.info(
                f"Final hidden state shape after microsaccade intervention: {hidden_states.shape}"
            )
        logits = self._last_linear_forward(hidden_states)  # [batch, seq_len, vocab]
        probs = torch.softmax(logits[:, -1, :], dim=-1)

        # Pass all relevant outputs to cpu!
        return BypassedOutput(
            input_ids=input_ids.cpu(),
            patched_probs=probs.cpu(),
            all_hidden_states=[h.cpu() for h in all_hidden_states],
            all_attentions=[a.cpu() if a is not None else None for a in all_attentions],
        )

    @staticmethod
    def _microsaccade_effects(
        patched_output: BypassedOutput,
        baseline_probs: torch.Tensor,
        attention_matrices,
    ) -> Float[torch.Tensor, "batch_size n_effects"]:
        """Stack per-example norms of the probability and per-head attention diffs.

        Returns ``[batch_size, n_effects]``. Column 0 is the probability effect;
        the remaining columns are layer-major, head-minor attention effects.
        Assumes ``_fro_norm_per_batch`` returns one value per batch element
        (TODO confirm).
        """
        effects = [_fro_norm_per_batch(patched_output.patched_probs - baseline_probs)]

        # Free GPU memory
        torch.cuda.empty_cache()
        baseline_attn = [a.cpu() for a in attention_matrices]

        if patched_output.all_attentions:
            for attn_idx, attn in enumerate(patched_output.all_attentions):
                diff_attn: Float[
                    torch.Tensor, "batch_size heads seq_len seq_len"
                ] = attn - baseline_attn[attn_idx]
                for head in range(diff_attn.size(1)):
                    effects.append(_fro_norm_per_batch(diff_attn[:, head, :, :]))

        # stack -> [n_effects, batch]; transpose -> [batch, n_effects]
        return torch.stack(effects).transpose(0, 1)

    @torch.no_grad()
    def scan_microsaccade(self, code: Union[str, List[str]]) -> Tuple[
        Float[torch.Tensor, "batch_size n_effects"],
        Union[torch.Tensor, None],
    ]:
        """Compute microsaccade causal effects for ``code``.

        In shadow-run mode, only the estimated FLOPs are returned, as
        ``(tensor([flops]), torch.empty(0))``.

        Side effect: if ``hidden_layers`` is empty, it is set to every layer
        index.

        Returns:
            ``(effects, None)``. ``effects`` is ``[batch_size, n_effects]`` as
            produced by :meth:`_microsaccade_effects`.
        """
        if isinstance(code, str):
            code = [code]

        if self.shadow_run:
            seq_len = len(self._tokenize_code(code)["input_ids"][0])
            spec = spec_microsaccades(batch_size=len(code), seq_len=seq_len)
            flops = estimate_scan_flops(
                self.model,
                seq_len=spec.passes[0][1],
                scan_spec=spec,
                assume_full_sequence_lm_head=self.assume_full_sequence_lm_head,
            )
            logger.info(f"[ShadowRun] Microsaccades FLOPs ≈ {_human(flops)}")
            return torch.tensor([flops]), torch.empty(0)

        # Baseline: use the model's own next-token probabilities. The returned
        # hidden_states[-1] is already final-normed, so re-applying the norm via
        # the LM-head helper would normalise it twice.
        hidden_states, baseline_probs, inputs, attention_matrices = self._llm_forward(code)
        baseline_probs = baseline_probs.detach().cpu()

        # If no layers specified, default to all except embedding
        if not getattr(self, "hidden_layers", []):
            self.hidden_layers = list(range(len(hidden_states) - 1))

        input_ids = inputs["input_ids"]
        attention_mask = inputs.get("attention_mask", None)

        intervened = self._embed_with_intervention(input_ids)
        patched_output = self._run_intervened_decoder(
            input_ids, intervened, attention_mask
        )
        effects = self._microsaccade_effects(
            patched_output, baseline_probs, attention_matrices
        )
        return effects, None

    def scan_function(self, code_entry):
        """Return ``(causal_effects, None)``; this scanner computes no stats."""
        causal_effects, _ = self.scan_microsaccade(code_entry)
        return causal_effects, None


class AttentionScanner(Scanner):
    """Token-wise scanner measuring how each token perturbs attention.

    For every token position the token is overwritten with
    ``replacement_token`` and the model is re-run (through
    ``scan_with_layer_patch_base`` with ``layer_idx_to_bypass=inf``). The
    Frobenius norm of the change in the attention matrices of the chosen
    layers (optionally per head) is recorded as that token's causal effect.
    """

    def __init__(
        self,
        model,
        tokenizer,
        scanning_type: ScanningType,
        hidden_layers: list = None,
        heads: list = None,
        seq_length: int = None,
        device: str = "cpu",
        verbose: bool = False,
        replacement_token: str = "-",
        batch_size: int = 1,
        use_cache: bool = True,
        combine_causal_effects: bool = True,
        shadow_run: bool = False,
        assume_full_sequence_lm_head: bool = False,
    ):
        super().__init__(
            model,
            tokenizer,
            scanning_type,
            hidden_layers,
            heads,
            seq_length,
            device,
            verbose,
            replacement_token,
            batch_size,
            use_cache,
            combine_causal_effects,
            shadow_run,
            assume_full_sequence_lm_head,
        )

    @torch.no_grad()
    def scan_with_token_replacement(
        self, code: List[str], layer_idx_to_bypass: int, token_to_replace_idx: List[int]
    ) -> BypassedOutput:
        """Run the model on copies of ``code``, each with one token replaced.

        Batch row ``b`` is the tokenized ``code`` with the token at position
        ``token_to_replace_idx[b]`` overwritten by ``self.replacement_token_id``.

        Args:
            code: Input to tokenize. It must tokenize to a single row
                (``input_ids`` of shape ``[1, T]``) because that row is
                expanded to the batch size.
            layer_idx_to_bypass: Forwarded unchanged to
                ``scan_with_layer_patch_base``.
            token_to_replace_idx: One position per batch row. Negative values
                count from the end, Python-style.

        Returns:
            The ``BypassedOutput`` produced by ``scan_with_layer_patch_base``.

        Raises:
            IndexError: If any position lies outside ``[-T, T)``.
        """
        inputs = self._tokenize_code(code)
        input_ids = inputs["input_ids"]  # expected [1, T] on self.device
        attention_mask = inputs.get("attention_mask", None)

        seq_len = input_ids.size(1)
        # dtype=long so an empty list still yields a valid index tensor.
        positions = torch.as_tensor(
            token_to_replace_idx, dtype=torch.long, device=input_ids.device
        )
        # Validate BEFORE wrapping negatives: after ``% seq_len`` every value
        # is in range, so a post-modulo check could never catch a bad index.
        if ((positions < -seq_len) | (positions >= seq_len)).any():
            raise IndexError(
                f"token_to_replace_idx={token_to_replace_idx} out of range for seq_len {seq_len}"
            )
        positions = positions % seq_len

        n_rows = positions.numel()
        # ``expand`` returns a view sharing storage with ``input_ids``; clone
        # it so the in-place write below touches an independent [B, T] tensor.
        new_input_ids = input_ids.expand(n_rows, -1).clone()
        rows = torch.arange(n_rows, device=new_input_ids.device)
        new_input_ids[rows, positions] = self.replacement_token_id

        new_attention_mask = None
        if attention_mask is not None:
            new_attention_mask = attention_mask.expand(n_rows, -1).clone()

        # Drop references to the single-row tensors before the batched forward.
        del inputs, input_ids, attention_mask

        return self.scan_with_layer_patch_base(
            new_input_ids, new_attention_mask, layer_idx_to_bypass
        )

    @torch.no_grad()
    def scan_attention(self, code: List[str]) -> Tuple[
        Float[torch.Tensor, "batch_size total_tokens n_layers ..."],
        Float[torch.Tensor, "batch_size n_stats"],
    ]:
        """Scan every string at the token level.

        Args:
            code: Strings to scan. A bare ``str`` is wrapped into a
                one-element list (with a warning).

        Returns:
            ``(all_causal_effects, all_stats)``:

            * ``all_causal_effects`` has shape
              ``[len(code), sum(seq_lens), n_layers]``. A trailing heads axis is
              added when ``combine_causal_effects`` is False. Row ``i`` holds
              string ``i``'s per-token effects in its first ``seq_lens[i]``
              positions; the remainder is zero padding.
            * ``all_stats`` has shape ``[len(code), N_STATS]``: the statistics
              returned by ``scan_attention_base`` for each string.

            In a shadow run both elements are instead a one-element tensor
            holding the estimated FLOPs.
        """
        # Normalize first so the shadow run also iterates over strings rather
        # than over the characters of a bare ``str``.
        if isinstance(code, str):
            code = [code]
            logger.warning(
                "Code input is a string. For token-wise scanning, it should be a list of strings."
            )

        if self.shadow_run:
            # tokenized lengths per string
            toks = [self._tokenize_code(c)["input_ids"].shape[1] for c in code]
            bs = max(1, self.batch_size)
            total = 0.0
            for T in toks:
                spec = spec_token_wise_per_string(seq_len=T, batch_tokens=bs)
                total += estimate_scan_flops(
                    self.model,
                    seq_len=T,
                    scan_spec=spec,
                    assume_full_sequence_lm_head=self.assume_full_sequence_lm_head,
                )
            logger.info(
                f"[ShadowRun] TokenWise FLOPs for {len(code)} strings ≈ {_human(total)}"
            )
            return torch.tensor([total]), torch.tensor([total])

        # TODO(XXXX-2): This is not great. I want the batches to be on token level analysis,
        # but the scanner is designed for layer-wise scanning.
        seq_lens = [self._tokenize_code(c)["input_ids"].shape[1] for c in code]
        num_layers = (
            len(self.hidden_layers)
            if self.hidden_layers
            else self.model.config.num_hidden_layers
        )

        results = [self.scan_attention_base(c) for c in code]

        # Size the trailing axes from the actual per-token result so the
        # per-head output (combine_causal_effects=False) fits as well.
        trailing = tuple(results[0][0].shape[1:]) if results else (num_layers,)
        # NOTE(review): the token axis is sized by the SUM of all lengths, so
        # each row carries mostly padding; max(seq_lens) would suffice, but the
        # shape is kept for backward compatibility.
        total_tokens = sum(seq_lens)
        all_causal_effects = torch.zeros(
            (len(code), total_tokens, *trailing), device=self.device
        )
        all_stats = torch.zeros((len(code), N_STATS), device=self.device)

        for i, (token_effects, stats) in enumerate(results):
            all_causal_effects[i, : seq_lens[i]] = token_effects
            # Stats are per string, not per token: fill the whole row.
            all_stats[i] = stats

        return all_causal_effects, all_stats

    @torch.no_grad()
    def scan_attention_base(self, code: str) -> Tuple[
        Float[torch.Tensor, "n_tokens n_layers ..."],
        Float[torch.Tensor, "n_stats"],
    ]:
        """Measure, for each token of ``code``, how replacing it changes attention.

        Tokens are processed in chunks of ``max(1, self.batch_size)``
        replacements per forward pass. The reference (unreplaced) attention
        matrices are cached in ``self._cache`` when ``use_cache`` is set.

        Returns:
            ``(token_effects, moments)``. ``token_effects`` is
            ``[seq_len, n_layers]``, or ``[seq_len, n_layers, n_heads]`` when
            ``combine_causal_effects`` is False. Each entry is the Frobenius
            norm of the attention change. ``moments`` is ``_stat_features``
            applied to the per-token overall effect.
        """
        tokens, tokens_ids = self._get_tokens(code)
        seq_len = len(tokens_ids)

        # Reference attention matrices for the unmodified input.
        cache_key = f"original_{hash(code)}"
        if self.use_cache and cache_key in self._cache:
            attention_matrices = self._cache[cache_key]
        else:
            with torch.inference_mode():
                _, _, _, attention_matrices = self._llm_forward_tokens(tokens)
                if self.use_cache:
                    self._cache[cache_key] = attention_matrices

        attn_all = torch.stack(attention_matrices, dim=0).to(
            self.device
        )  # presumably [L_total, 1, H, T, T]
        chosen_layers = (
            self.hidden_layers if self.hidden_layers else list(range(attn_all.size(0)))
        )
        # Scanner.__init__ turns heads=None into []; an empty list would select
        # zero heads (every effect 0), so treat it as "all heads".
        heads = getattr(self, "heads", None) or None

        attn_ref = attn_all[chosen_layers]
        if heads is not None:
            attn_ref = attn_ref[:, :, heads]
        del attn_all
        n_heads = attn_ref.size(2)

        if self.combine_causal_effects:
            token_effects = torch.zeros(
                (seq_len, len(chosen_layers)), device=self.device
            )
        else:
            token_effects = torch.zeros(
                (seq_len, len(chosen_layers), n_heads), device=self.device
            )
        causal_effect_token = torch.zeros(seq_len, device=self.device)

        # Loop-invariant LongTensor index for layer selection.
        layer_index = torch.tensor(chosen_layers, dtype=torch.long, device="cpu")
        step = max(1, self.batch_size)

        for start in trange(0, seq_len, step):
            end = min(start + step, seq_len)

            with torch.inference_mode():
                patched = self.scan_with_token_replacement(
                    code,
                    layer_idx_to_bypass=float("inf"),  # never matches: no bypass
                    token_to_replace_idx=list(range(start, end)),
                )
            attn_intv = torch.stack(patched.all_attentions, dim=0).to(
                self.device
            )  # [L_total, B, H, T, T]
            del patched

            try:
                attn_patched = attn_intv[layer_index]
            except (RuntimeError, IndexError) as e:
                logger.error(
                    f"Error selecting layers: {e}. Will move everything to cpu"
                )
                # Slice on cpu, then move the selection back to the device.
                attn_patched = attn_intv.to("cpu")[chosen_layers].to(self.device)
            del attn_intv

            if heads is not None:
                attn_patched = attn_patched[:, :, heads]

            # The reference batch axis of size 1 broadcasts against B.
            diff = attn_patched - attn_ref  # [L, B, H, T, T]
            del attn_patched

            if self.combine_causal_effects:
                fro = diff.pow(2).sum(dim=(2, 3, 4)).sqrt()  # [L, B]
                token_effects[start:end] = fro.transpose(0, 1)  # [B, L]
                causal_effect_token[start:end] = fro.pow(2).sum(dim=0).sqrt()
            else:
                fro = diff.pow(2).sum(dim=(3, 4)).sqrt()  # [L, B, H]
                token_effects[start:end] = fro.permute(1, 0, 2)  # [B, L, H]
                causal_effect_token[start:end] = fro.pow(2).sum(dim=(0, 2)).sqrt()

            del diff, fro
            gc.collect()

        del attn_ref

        intervened_moments: Float[torch.Tensor, "n_stats"] = _stat_features(
            causal_effect_token.unsqueeze(0)
        )[0]

        return token_effects, intervened_moments

    def scan_function(self, code_entry):
        """Token-wise scan returning ``(None, stats)``."""
        _, stats = self.scan_attention(code_entry)
        return None, stats


class LayerWiseScanner(Scanner):
    """Layer-wise scanner for LLMs.

    Evaluates the effect each layer has on next-token prediction. The model
    is re-run once per entry of ``hidden_layers`` with that layer patched
    (via ``scan_with_layer_patch_base``). The change in the next-token
    probability distribution is compared with the unpatched run and
    reported as a per-layer causal effect, together with summary statistics.
    """

    def __init__(
        self,
        model,
        tokenizer,
        scanning_type: ScanningType,
        hidden_layers: list = None,
        heads: list = None,
        seq_length: int = None,
        device: str = "cpu",
        verbose: bool = False,
        replacement_token: str = "-",
        batch_size: int = 1,
        use_cache: bool = True,
        combine_causal_effects: bool = True,
        shadow_run: bool = False,
        assume_full_sequence_lm_head: bool = False,
    ):
        super().__init__(
            model,
            tokenizer,
            scanning_type,
            hidden_layers,
            heads,
            seq_length,
            device,
            verbose,
            replacement_token,
            batch_size,
            use_cache,
            combine_causal_effects,
            shadow_run,
            assume_full_sequence_lm_head,
        )

    @torch.no_grad()
    def scan_with_layer_patch(
        self, code: List[str], layer_idx_to_bypass: int
    ) -> BypassedOutput:
        """Tokenize ``code`` and run it with ``layer_idx_to_bypass`` patched."""
        inputs = self._tokenize_code(code)
        input_ids = inputs["input_ids"]
        attention_mask = inputs.get("attention_mask", None)
        return self.scan_with_layer_patch_base(
            input_ids, attention_mask, layer_idx_to_bypass
        )

    @torch.no_grad()
    def scan_layer(self, code: Union[str, List[str]]) -> Tuple[
        Float[torch.Tensor, "batch_size n_layers"],
        Float[torch.Tensor, "batch_size n_layers n_stats"],
    ]:
        """Measure each scanned layer's effect on the next-token distribution.

        Args:
            code: One string or a batch of strings.

        Returns:
            ``(layer_effects, stats)``. ``layer_effects[b, l]`` is the norm
            (``_fro_norm_per_batch``) of the probability change for sample
            ``b`` when layer ``l`` is patched. ``stats[b, l]`` holds
            ``_stat_features`` of that change (mean, std, range, skewness,
            kurtosis). In a shadow run, the first element is a one-element
            FLOPs tensor and the second is empty.

        Side effect:
            If ``self.hidden_layers`` is empty, it is set to all non-embedding
            layers, and this persists across calls.
        """
        if isinstance(code, str):
            code = [code]

        if self.shadow_run:

            seq_len = len(self._tokenize_code(code)["input_ids"][0])

            n_layers_scanned = (
                len(self.hidden_layers)
                if self.hidden_layers
                else self.model.config.num_hidden_layers
            )
            spec = spec_layer_wise(
                batch_size=len(code), seq_len=seq_len, n_layers_scanned=n_layers_scanned
            )
            flops = estimate_scan_flops(
                self.model,
                seq_len=spec.passes[0][1],
                scan_spec=spec,
                assume_full_sequence_lm_head=self.assume_full_sequence_lm_head,
            )
            logger.info(f"[ShadowRun] LayerWise FLOPs ≈ {_human(flops)}")
            return torch.tensor([flops]), torch.empty(0)

        hidden_states, outputs, inputs, attention_matrices = self._llm_forward(code)
        # Only the hidden states are needed below; drop the rest so they are
        # not kept alive through the whole patching loop.
        del outputs, inputs, attention_matrices
        next_token_probs = self._get_next_token_probabilities(hidden_states)
        next_token_probs = next_token_probs.detach().cpu()
        n_hidden_states = len(hidden_states)
        del hidden_states

        # If no layers specified, default to all except embedding
        if not self.hidden_layers:
            self.hidden_layers = list(range(n_hidden_states - 1))

        layer_effects = []
        stats = []

        # NOTE(review): ``layer_idx`` is the POSITION within self.hidden_layers
        # (0..n-1), not the layer id stored there. With non-contiguous layers
        # such as [5, 10] this patches 0 and 1 unless scan_with_layer_patch_base
        # maps positions through self.hidden_layers. Confirm.
        for layer_idx in range(len(self.hidden_layers)):

            # Scan with layer patching
            patched_output: BypassedOutput = self.scan_with_layer_patch(code, layer_idx)
            if self.verbose:
                logger.info(
                    f"Patched output for layer {layer_idx}: {patched_output.patched_probs.shape}"
                )
            # Compute L2 norm across vocab dim (and any others except batch)
            diff_probs = patched_output.patched_probs - next_token_probs
            layer_effects.append(_fro_norm_per_batch(diff_probs))
            stats.append(
                _stat_features(diff_probs)
            )  # mean, std, range, skewness, kurtosis

        # torch.stack adds the layer axis FIRST (one entry per loop iteration).
        layer_effects: Float[torch.Tensor, "n_layers batch_size"] = torch.stack(
            layer_effects
        )
        stats: Float[torch.Tensor, "n_layers batch_size n_stats"] = torch.stack(stats)

        # Swap the first two axes so the batch axis leads.
        layer_effects: Float[torch.Tensor, "batch_size n_layers"] = torch.einsum(
            "b l -> l b", layer_effects
        )
        stats: Float[torch.Tensor, "batch_size n_layers n_stats"] = torch.einsum(
            "b l s -> l b s", stats
        )
        return layer_effects, stats

    def scan_function(self, code_entry):
        """Layer-wise scan returning ``(causal_effects, stats)``."""
        causal_effects, stats = self.scan_layer(code_entry)
        return causal_effects, stats


if __name__ == "__main__":
    # Demo: run a microsaccades scan on a few snippets with GPT-2 and check
    # that the causal effects survive a JSONL save/load round trip.

    import os

    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    # Define local paths for model and tokenizer
    model_name = "gpt2"
    local_model_path = "./local_gpt2_model"  # Local directory to save/load the model

    # Check if model exists locally, if not download and save it.
    # attn_implementation="eager" is requested in both branches.
    if not os.path.exists(local_model_path):
        print("Downloading and saving model locally...")
        model = GPT2LMHeadModel.from_pretrained(model_name, attn_implementation="eager")
        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        # Set before saving so the saved tokenizer includes the pad token.
        tokenizer.pad_token = tokenizer.eos_token

        # Save model and tokenizer locally
        model.save_pretrained(local_model_path)
        tokenizer.save_pretrained(local_model_path)
    else:
        print("Loading model from local directory...")
        model = GPT2LMHeadModel.from_pretrained(
            local_model_path, attn_implementation="eager"
        )
        tokenizer = GPT2Tokenizer.from_pretrained(local_model_path)
        tokenizer.pad_token = tokenizer.eos_token

    code_example = [
        "def hello_world():\n    print('Hello, world!')",
        "def goodbye_world():\n    print('Goodbye, world!')",
        "def greet():\n    print(f'Hello, XXXX-1!')",
    ]

    ###### # Example usage for Microsaccades
    scanning_type = ScanningType.MICROSACCADES
    scanner = MicrosaccadesScanner(
        model,
        tokenizer,
        scanning_type=scanning_type,
        device="cpu",
        seq_length=128,
        batch_size=2,
        # shadow_run=True,
        # assume_full_sequence_lm_head=True,  # For testing FLOPs estimation
        # random_positional_encoding=True,  # Example of using random positional encoding
        gaussian_positional_encoding=True,  # Example of using Gaussian positional encoding
    )
    stats, causal_effects = scanner.scan(code_example)

    results_path = "microsaccades_causal_effects.jsonl"
    try:
        save_tensor_jsonl(causal_effects, results_path)
        loaded_diffs = load_tensors_jsonl(results_path)[0]

        assert (
            loaded_diffs.shape == causal_effects.shape
        ), "Loaded diffs shape does not match original."
    finally:
        # Clean up the saved file even when the round-trip check fails.
        if os.path.exists(results_path):
            os.remove(results_path)

    # ###### # Example usage for layer-wise scanning

    # scanning_type = ScanningType.LAYER_WISE
    # scanner = LayerWiseScanner(
    #     model,
    #     tokenizer,
    #     scanning_type=scanning_type,
    #     hidden_layers=[0, 1, 2, 3, 4],
    #     device="cpu",
    #     seq_length=128,
    #     batch_size=2,
    #     shadow_run=True,
    #     assume_full_sequence_lm_head=True,  # For testing FLOPs estimation
    # )
    # _, causal_effects = scanner.scan(code_example)
    # save_tensor_jsonl(causal_effects, "layer_causal_effects.jsonl")
    # loaded_diffs = load_tensors_jsonl("layer_causal_effects.jsonl")[0]

    # assert (
    #     loaded_diffs.shape == causal_effects.shape
    # ), "Loaded diffs shape does not match original."

    # os.remove("layer_causal_effects.jsonl")  # Clean up the saved file

    # ###### # Example usage for token-wise scanning

    # scanning_type = ScanningType.TOKEN_WISE
    # scanner = AttentionScanner(
    #     model,
    #     tokenizer,
    #     scanning_type=scanning_type,
    #     heads=[2],
    #     device="cpu",
    #     batch_size=2,
    #     shadow_run=True,
    #     assume_full_sequence_lm_head=True,  # For testing FLOPs estimation
    # )
    # stats, _ = scanner.scan(code_example)

    # save_tensor_jsonl(stats, "token_wise_diffs.jsonl")
    # loaded_diffs = load_tensors_jsonl("token_wise_diffs.jsonl")[0]
    # assert (
    #     loaded_diffs.shape == stats.shape
    # ), "Loaded diffs shape does not match original."

    # os.remove("token_wise_diffs.jsonl")  # Clean up the saved file
