from dataclasses import dataclass


@dataclass
class ModelConfig:
    """Typed record of model-related settings.

    ``@dataclass`` generates ``__init__``, ``__repr__`` and ``__eq__`` from
    the fields below, in declaration order. The field order is therefore
    part of the positional constructor interface and must not be changed.
    No field has a default, so every value must be supplied.

    NOTE(review): this file does not define what each field means; the
    grouping comments below are assumptions based on the names only and
    should be confirmed against the code that consumes this config.
    """

    # Model selection (presumably string identifiers -- confirm valid values).
    model: str
    likelihood: str

    # Network size (presumably layer widths / depth -- confirm).
    latent_dim: int
    hidden_dim: int
    num_layers: int

    # Penalty / regularisation settings (semantics not visible here).
    penalty_scale: float
    entropy_relative_scale: float
    penalty_exp_factor: float
    cvae_penalty: str
    kl_beta: float

    # Boolean switches and remaining numeric options.
    use_batchnorm: bool
    learn_sigma: bool
    bandwidth: float
    penalise_z: bool
    rbf_version: int


@dataclass
class TrainConfig:
    """Typed record of training-related settings.

    ``@dataclass`` generates ``__init__``, ``__repr__`` and ``__eq__`` from
    the four fields below, in declaration order. All of them are required
    because none declares a default.
    """

    # Integer settings.
    batch_size: int
    num_epochs: int

    # Float settings.
    learning_rate: float
    # NOTE(review): the meaning of ``gamma`` is not visible in this file
    # (possibly a scheduler decay factor) -- confirm against the caller.
    gamma: float


@dataclass
class TestConfig:
    """Typed record of test-time settings.

    ``@dataclass`` generates ``__init__``, ``__repr__`` and ``__eq__`` from
    the two annotated fields below, in declaration order; both are required.
    """

    # The class name matches pytest's default ``Test*`` collection pattern.
    # Without this marker pytest tries to collect the class whenever it is
    # imported into a test module and emits a PytestCollectionWarning
    # because of the generated ``__init__``. The attribute is deliberately
    # un-annotated so that it is NOT treated as a dataclass field.
    __test__ = False

    # NOTE(review): the element type of ``categories`` and the accepted
    # values of ``condition`` are not visible in this file -- confirm.
    categories: list
    condition: str
