from dataclasses import dataclass, field
from typing import List

from hydra.core.config_store import ConfigStore


@dataclass
class WandbConfig:
    """Weights & Biases (wandb) specific settings."""

    # Your wandb entity; the "???" placeholder is meant to be filled in.
    entity: str = "???"
    # wandb project name.
    project: str = "local_adversarial_attack"
    # Prefix added to wandb run names, e.g., "dev" or "exp1".
    run_name_prefix: str = ""


@dataclass
class BlackboxConfig:
    """Settings used when evaluating against blackbox models."""

    # Blackbox model names; can be gpt4v, claude, gemini, gpt_score.
    # A factory is used so every instance gets its own list.
    model_name: List[str] = field(default_factory=lambda: ["gpt4v"])
    batch_size: int = 1
    # NOTE(review): unit is not stated in this file -- presumably seconds.
    timeout: int = 30
    # Number of images to process in parallel.
    parallel_images: int = 1


@dataclass
class DataConfig:
    """Settings for data loading and on-disk locations."""

    batch_size: int = 1
    num_samples: int = 100
    # NOTE(review): "cle" presumably means the clean (source) images -- confirm.
    cle_data_path: str = "resources/images/bigscale"
    # NOTE(review): "tgt" presumably means the target images -- confirm.
    tgt_data_path: str = "resources/images/target_images"
    # Output directory path.
    output: str = "./img_output"
    # Location of the retrieved embeddings.
    retrieval_path: str = "resources/retrieved_embeddings"


@dataclass
class OptimConfig:
    """Hyper-parameters of the optimization procedure."""

    alpha: float = 1.0
    epsilon: int = 8
    steps: int = 300
    optimizer: str = "adam"
    momentum: float = 0.9
    momentum_decay: float = 0.9
    # Alignment mode: "pooler" or "tm" (tm = trajectory matching).
    align: str = "pooler"
    # Layer indices to align with; built per instance via a factory.
    tm_idx: List[int] = field(default_factory=lambda: [-4, -3, -2, -1])
    # Used for pooler_weighted.
    beta: float = 0.5
    # Flag to explicitly enable/disable retrieval.
    use_retrieval: bool = False
    # Number of passes for multi-pass attacks.
    multi_pass_num: int = 1


@dataclass
class ModelConfig:
    """Model-specific parameters."""

    input_res: int = 336
    crop_scale: tuple = (0.5, 0.9)
    ensemble: bool = True
    target_crop: bool = False
    # Can be "cpu", "cuda:0", "cuda:1", etc.
    device: str = "cuda:0"
    # List of models to use: L336, B16, B32, Laion.
    # The default is a real list (matching the annotation and the list that
    # callers pass in), created by a factory so instances never share it.
    backbone: List[str] = field(
        default_factory=lambda: ["L336", "B16", "B32", "Laion"]
    )
    # Number of target crops to align with each source image.
    target_num: int = 1


@dataclass
class MainConfig:
    """Main configuration combining all sub-configs.

    Each nested config is created through ``default_factory``. Dataclass
    instances are unhashable by default, so using one directly as a field
    default raises ``ValueError`` on Python 3.11+, and on older versions it
    would be a single object shared by every ``MainConfig`` instance.
    """

    data: DataConfig = field(default_factory=DataConfig)
    optim: OptimConfig = field(default_factory=OptimConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    wandb: WandbConfig = field(default_factory=WandbConfig)
    blackbox: BlackboxConfig = field(default_factory=BlackboxConfig)
    # Attack name; can be [fgsm, mifgsm, pgd].
    attack: str = "fgsm"


# Register configs for different settings
@dataclass
class Ensemble3ModelsConfig(MainConfig):
    """Configuration for ensemble_3models.py.

    Overrides the ``data`` and ``model`` defaults of ``MainConfig``. The
    overrides are built by factories so that each instance gets fresh
    sub-configs; a dataclass instance used directly as a default is rejected
    with ``ValueError`` on Python 3.11+.
    """

    data: DataConfig = field(default_factory=lambda: DataConfig(batch_size=1))
    # Three-model ensemble: the default backbone list without "L336".
    model: ModelConfig = field(
        default_factory=lambda: ModelConfig(backbone=["B16", "B32", "Laion"])
    )
