# autointerp_hf/config.py
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class AutoInterpEvalConfig:
    """
    Configuration controlling the end-to-end auto-interpretability pipeline.

    This is inspired by the original AutoInterpEvalConfig from saebench, but modified
    for a pure-HuggingFace / custom-hook setup.

    Key responsibilities:
    - controls which latents we try to interpret (n_latents / override_latents)
    - controls how we sample examples for generation & scoring
    - holds limits used by the judge model (the judge itself is configured elsewhere)

    Computed in ``__post_init__``:
    - ``latents``: a copy of ``override_latents`` as a list, or None when no
      override was given.
    - ``n_latents``: when ``override_latents`` is given, it is set to the number
      of override latents so the two fields never disagree.
    - ``batch_size``: defaults to 32 when left as None.

    Raises:
        ValueError: if a count/size setting is negative (or non-positive where a
            zero value can never work), or if ``act_threshold_frac`` is outside
            ``[0, 1]``.
    """

    # -------------------------
    # Model + SAE hooks
    # -------------------------
    model_name_or_path: str           # e.g. "/path/to/local/Qwen3-1.7B-Base" or "Qwen/Qwen2.5-7B"
    hook_module_path: str             # e.g. "model.layers.20" or "auto"
    device: str = "cuda"              # e.g. "cuda", "cuda:0", or "cpu"
    torch_dtype: str = "auto"         # "float32","bfloat16","float16","auto"

    # -------------------------
    # Latent selection
    # -------------------------
    n_latents: Optional[int] = 100     # how many latents to evaluate (overwritten by len(override_latents) if given)
    override_latents: Optional[List[int]] = None
    dead_latent_threshold: float = 15  # minimum activation count for a latent to be considered "alive"
    random_seed: int = 42

    # -------------------------
    # Dataset / tokenization
    # -------------------------
    dataset_name: str = "common-pile/comma_v0.1_training_dataset"
    heldout_jsonl: Optional[str] = None   # local .jsonl path; if set, dataset_name is ignored
    heldout_text_key: str = "text"        # field name in jsonl rows (or HF rows)
    skip_first_n_examples: int = 0        # skip N rows for a pseudo-heldout region
    total_tokens: int = 2_000_000          # total tokens to gather from dataset
    llm_context_size: int = 128            # context length for each row in the tokenized dataset
    batch_size: Optional[int] = None       # batch size for collecting activations; None -> 32

    # -------------------------
    # Windowing / sampling
    # -------------------------
    buffer: int = 10                       # number of tokens to the left/right of each "center" token
    no_overlap: bool = True                # avoid overlapping windows for top-k picks
    act_threshold_frac: float = 0.01       # fraction * max_activation used to mark which tokens are "active"

    # -------------------------
    # Judge / scoring toggles
    # -------------------------
    scoring: bool = True
    max_tokens_in_explanation: int = 30    # max new tokens judge can generate for explanation
    use_demos_in_explanation: bool = True  # include demo explanations in the system prompt

    # -------------------------
    # Example counts
    # -------------------------
    n_top_ex_for_generation: int = 10        # top activations used for explanation prompt
    n_iw_sampled_ex_for_generation: int = 5  # importance-weighted "medium" activations for explanation

    n_top_ex_for_scoring: int = 2          # leftover top acts for scoring
    n_random_ex_for_scoring: int = 10      # random windows for negative examples
    n_iw_sampled_ex_for_scoring: int = 2   # leftover medium activations for scoring

    # A computed field: which latents to interpret in this run
    latents: Optional[List[int]] = field(default=None, init=False)

    def __post_init__(self):
        """Derive computed fields, fill defaults, then validate the settings."""
        if self.override_latents is not None:
            # Copy so later mutation of the caller's list cannot change this config.
            self.latents = list(self.override_latents)
            # Keep n_latents consistent with the explicit latent list; otherwise it
            # would still report its default even though a different number of
            # latents will actually be evaluated.
            self.n_latents = len(self.latents)
        else:
            self.latents = None

        if self.batch_size is None:
            # A small-ish default that is often safe for a 7B model.
            self.batch_size = 32

        self._validate()

    def _validate(self) -> None:
        """Reject settings that can never yield a meaningful run (raises ValueError)."""
        non_negative = {
            "n_latents": self.n_latents,
            "dead_latent_threshold": self.dead_latent_threshold,
            "skip_first_n_examples": self.skip_first_n_examples,
            "buffer": self.buffer,
            "n_top_ex_for_generation": self.n_top_ex_for_generation,
            "n_iw_sampled_ex_for_generation": self.n_iw_sampled_ex_for_generation,
            "n_top_ex_for_scoring": self.n_top_ex_for_scoring,
            "n_random_ex_for_scoring": self.n_random_ex_for_scoring,
            "n_iw_sampled_ex_for_scoring": self.n_iw_sampled_ex_for_scoring,
        }
        for name, value in non_negative.items():
            # n_latents may legitimately be None; only check real numbers.
            if value is not None and value < 0:
                raise ValueError(f"{name} must be >= 0, got {value!r}")

        positive = {
            "total_tokens": self.total_tokens,
            "llm_context_size": self.llm_context_size,
            "batch_size": self.batch_size,
            "max_tokens_in_explanation": self.max_tokens_in_explanation,
        }
        for name, value in positive.items():
            if value <= 0:
                raise ValueError(f"{name} must be > 0, got {value!r}")

        # Written as a positive range check so NaN is rejected as well.
        if not (0.0 <= self.act_threshold_frac <= 1.0):
            raise ValueError(
                f"act_threshold_frac must be in [0, 1], got {self.act_threshold_frac!r}"
            )

    @property
    def n_ex_for_generation(self) -> int:
        """
        Total number of examples used for generation prompt (explanation):
        top-activation examples plus importance-weighted examples.
        """
        return self.n_top_ex_for_generation + self.n_iw_sampled_ex_for_generation

    @property
    def n_ex_for_scoring(self) -> int:
        """
        Total number of examples used for scoring prompt (prediction).

        We'll show:
          - n_top_ex_for_scoring "high activation" windows (positive)
          - n_iw_sampled_ex_for_scoring "medium" activation windows (positive-ish)
          - n_random_ex_for_scoring random windows (negative-ish)

        We'll compare to ground truth to compute accuracy.
        """
        return (
            self.n_top_ex_for_scoring
            + self.n_iw_sampled_ex_for_scoring
            + self.n_random_ex_for_scoring
        )

    @property
    def max_tokens_in_prediction(self) -> int:
        """
        Token budget for the judge's prediction response.

        Prediction responses are short: comma-separated integers, maybe 'None'.
        The budget is two tokens per scoring example plus a margin of 5.
        """
        return 2 * self.n_ex_for_scoring + 5
