"""Configuration dataclasses for experiments."""

from dataclasses import dataclass, field
from pathlib import Path


@dataclass
class EmbeddingConfig:
    model_name: str = "dinov2_vits14"  # dinov2_vits14 (384), dinov2_vitb14 (768), dinov2_vitl14 (1024)
    dataset: str = "imagenet-1k"
    split: str = "train"
    batch_size: int = 128
    num_workers: int = 4
    max_samples: int | None = None  # None = all samples
    output_path: Path = field(default_factory=lambda: Path("embeddings.h5"))
    streaming: bool = False
    cache_dir: str | None = None  # HuggingFace cache location
    device: str = "cuda"


@dataclass
class SAEConfig:
    """Hyperparameters for SAE training.

    NOTE(review): "SAE" presumably means sparse autoencoder with TopK
    activation; the training code that consumes these values is not in this
    module, so confirm field semantics there.
    """

    # --- model shape ---
    # Default equals the DINO_DIMS entry for "dinov2_vits14".
    input_dim: int = 384
    dict_size: int = 4096
    # TopK sparsity
    k: int = 16

    # --- optimisation ---
    learning_rate: float = 1e-3
    batch_size: int = 2048
    num_steps: int = 50_000

    # --- bookkeeping ---
    val_every: int = 1_000
    device: str = "cuda"


@dataclass
class KSVDConfig:
    """Hyperparameters for the KSVD baseline.

    NOTE(review): presumably K-SVD dictionary learning; the code that reads
    these values is not in this module, so confirm field semantics there.
    """

    dict_size: int = 4096
    # Sparsity level; the default matches the default of SAEConfig.k (16).
    nnz_per_col: int = 16
    batch_size: int = 2**16  # 65536
    iters_per_batch: int = 1
    num_repeats: int = 3


# Lookup from DINOv2 model name to an integer dimension.
# NOTE(review): presumably the embedding width each model produces — confirm
# against the code that extracts embeddings.
DINO_DIMS: dict[str, int] = {
    "dinov2_vits14": 384,
    "dinov2_vitb14": 768,
    "dinov2_vitl14": 1024,
    "dinov2_vitg14": 1536,
}
