from dataclasses import dataclass
from typing import Optional

@dataclass
class TrainerConfig:
    """Options controlling a training run.

    Required fields:
        epochs: Number of training epochs.
        approach: Name of the training approach; the accepted values are
            decided by whatever consumes this config (not visible here).
        lr: Learning rate.

    Optional fields:
        weight_decay: Weight-decay coefficient (default ``1e-5``).
        seed: Random seed (default ``42``).
        res_dir: Results directory, or ``None`` when not set.
        wandb: Flag presumably toggling Weights & Biases logging — confirm
            against the trainer.
        early_stopping: Flag presumably enabling early stopping — confirm.
        patience: Patience value (default ``10``); presumably the number of
            epochs without improvement tolerated before stopping — confirm.
    """

    epochs: int
    approach: str
    lr: float
    weight_decay: float = 1e-5
    seed: int = 42
    # Defaults to None, so the annotation must admit None (PEP 484 dropped
    # implicit Optional for `= None` defaults).
    res_dir: Optional[str] = None
    wandb: bool = False
    early_stopping: bool = True
    patience: int = 10

@dataclass()
class DatasetConfig:
    """Settings describing where data lives and how it is batched and loaded.

    ``data_dir``, ``batch_size`` and ``num_workers`` must always be given;
    ``unconf_split`` defaults to ``False`` and ``num_classes`` to ``3``.
    """

    # --- mandatory ---
    data_dir: str
    batch_size: int
    num_workers: int
    # --- optional ---
    unconf_split: bool = False
    num_classes: int = 3

@dataclass()
class ModelConfig:
    """Model selection settings; every field is required.

    ``model`` names the model, ``pretrained_model`` is a boolean flag and
    ``pretrained_path`` is a path string (how the three interact is decided
    by the consumer of this config, not shown here).
    """

    model: str
    pretrained_model: bool
    pretrained_path: str

@dataclass()
class BooleanConfig:
    """Collection of on/off switches; currently only ``save_model`` (off by default)."""

    save_model: bool = False

@dataclass
class FeatureSelectorConfig:
    """Hyper-parameters for the feature-selector components.

    Only ``mask_size`` is required; every other field has a default.

    Attributes:
        mask_size: Size of the feature mask (units not visible here).
        lr_merlin / lr_morgana: Learning rates for the "merlin" and "morgana"
            components (default ``0.1`` each).
        weight_decay_fs: Feature-selector weight decay, or ``None`` when unset.
        weight_decay_merlin / weight_decay_morgana: Weight decay for the two
            components (default ``1e-5`` each).
        gamma: Scalar coefficient (default ``1``); its role is defined by
            the consumer of this config — confirm.
        l1_penalty_coefficient / l2_penalty_coefficient /
        tv_penalty_coefficient: Penalty weights (defaults ``1``, ``0``, ``0``);
            "tv" presumably means total variation — confirm.
        sfw_max_iterations / sfw_patience: Iteration cap and patience for the
            "sfw" procedure (presumably a Frank-Wolfe variant — confirm).
        lr_fs: Feature-selector learning rate, or ``None`` when unset.
        unet_steps: Step count for a "unet" stage (default ``1``) — semantics
            defined elsewhere.
    """

    mask_size: int
    lr_merlin: float = 0.1
    lr_morgana: float = 0.1
    # Defaults to None, so the annotation must admit None (no implicit
    # Optional under PEP 484).
    weight_decay_fs: Optional[float] = None
    weight_decay_merlin: float = 0.00001
    weight_decay_morgana: float = 0.00001
    gamma: float = 1
    l1_penalty_coefficient: float = 1
    l2_penalty_coefficient: float = 0
    tv_penalty_coefficient: float = 0
    sfw_max_iterations: int = 350
    sfw_patience: int = 10
    lr_fs: Optional[float] = None
    unet_steps: int = 1
