from dataclasses import dataclass
from enum import Enum
from typing import Dict, Optional

import torch
from jaxtyping import Float, Int


class ScanningType(Enum):
    """Supported scanning strategies.

    Every member's value is its lower-case ``snake_case`` identifier, so a
    member can be recovered from its plain string form, e.g.
    ``ScanningType("layer_wise") is ScanningType.LAYER_WISE``.
    """

    LAYER_WISE = "layer_wise"
    """Scan layer by layer (layer-wise scanning)."""

    TOKEN_WISE = "token_wise"
    """Token-wise scanning."""

    MICROSACCADES = "microsaccades"
    """Microsaccades scanning. NOTE(review): semantics not visible here — document."""


@dataclass
class BypassedOutput:
    """Container for the results of running a model with a layer bypassed.

    This is a plain dataclass: fields are stored as given and no validation
    is performed on construction.
    """

    # Token IDs are integer tensors; annotating them as ``Float`` would make a
    # jaxtyping runtime check reject an ordinary int64 ``input_ids`` tensor.
    input_ids: Int[torch.Tensor, "batch seq_len"]
    """Token IDs of the input, shape ``(batch, seq_len)`` (presumably the
    tokenized code snippet — confirm against the producer)."""
    patched_probs: Float[torch.Tensor, "batch seq_len vocab_size"]
    """Probabilities after bypassing the layer, with a ``vocab_size`` axis for
    every (batch, position) pair."""
    all_hidden_states: Dict[str, Float[torch.Tensor, "batch seq_len hidden_size"]]
    """All hidden states from the model, keyed by string (presumably layer
    names — confirm against the producer)."""
    all_attentions: Optional[
        Dict[str, Float[torch.Tensor, "batch num_heads seq_len seq_len"]]
    ] = None
    """Optional attention weights keyed by string, each of shape
    ``(batch, num_heads, seq_len, seq_len)``; defaults to ``None`` when not
    supplied."""
