from dataclasses import dataclass
from typing import Optional
import torch.nn as nn

@dataclass
class TrainerConfig:
    """Settings bundle for a training run.

    ``epochs``, ``approach`` and ``lr`` carry no default and must be given
    (positionally in that order, or by keyword). Every remaining field is
    optional and falls back to the default declared below.

    NOTE(review): the per-field remarks are inferred from the field names
    only; the code consuming this config is not visible here -- confirm.
    """

    # --- required -----------------------------------------------------
    epochs: int  # presumably the number of training epochs
    approach: str  # presumably selects the training approach; valid values not shown
    lr: float  # presumably the optimiser learning rate

    # --- optional -----------------------------------------------------
    seed: int = 42
    wandb: bool = False  # presumably toggles Weights & Biases logging -- TODO confirm
    res_dir: str = "checkpoints"
    weight_decay: float = 1e-4
    early_stopping: bool = True
    patience: int = 10  # presumably only relevant with early_stopping -- TODO confirm

@dataclass
class DatasetConfig:
    """Settings bundle describing where data lives and how it is loaded.

    Only ``data_dir`` is mandatory; all other fields have defaults.

    NOTE(review): the per-field remarks are inferred from the field names
    only; the loader that reads this config is not visible here -- confirm.
    """

    # --- required -----------------------------------------------------
    data_dir: str  # presumably the root directory of the dataset

    # --- loading options ----------------------------------------------
    batch_size: int = 128
    enc_type: str = "one_hot_padded"  # encoding identifier; accepted values not shown here
    num_workers: int = 4  # presumably forwarded to a data loader -- TODO confirm

    # --- split options ------------------------------------------------
    unconf_split: bool = False
    partial_conf_ratio: float = 0.0
    partial_conf_dir: Optional[str] = None  # None means no directory was configured

@dataclass
class BooleanConfig:
    """Pair of on/off switches; both are off unless explicitly enabled.

    NOTE(review): what is saved, and where, is decided by code outside this
    file -- the remarks below are inferred from the field names only.
    """

    save_model: bool = False  # presumably persists the trained model when True
    save_confusion_matrix: bool = False  # presumably persists a confusion matrix when True

@dataclass
class ModelConfig:
    """Settings bundle identifying a model and its optional hyper-parameters.

    Only ``model`` is mandatory. The architecture-related fields default to
    ``None``, i.e. "not specified" at this level.

    NOTE(review): how an unset (``None``) value is resolved is up to the
    consuming code, which is not visible here -- confirm.
    """

    # --- required -----------------------------------------------------
    model: str  # model identifier; accepted values not shown here

    # --- pretrained weights -------------------------------------------
    pretrained_model: bool = False
    pretrained_path: Optional[str] = None  # presumably used when pretrained_model is True

    # --- architecture hyper-parameters (None = unspecified) -----------
    n_heads: Optional[int] = None
    set_transf_hidden: Optional[int] = None
    hidden_dim: Optional[int] = None
    dropout: Optional[float] = None

@dataclass
class FeatureSelectorConfig:
    """Settings bundle for the feature-selection stage.

    Every field has a default, so ``FeatureSelectorConfig()`` is valid.
    ``lr_fs`` and ``weight_decay_fs`` default to ``None`` (unspecified).

    NOTE(review): the grouping headers and per-field remarks are inferred
    from the field names only; the consuming code is not visible here, so
    treat them as assumptions to be confirmed.
    """

    # --- mask / segmentation ------------------------------------------
    segmentation_method: str = "topk"  # accepted values not shown here
    mask_size: int = 6

    # --- learning rates -----------------------------------------------
    lr_merlin: float = 1e-3
    lr_morgana: float = 1e-3
    lr_fs: Optional[float] = None  # None = unspecified; fallback (if any) lives elsewhere

    # --- objective weighting ------------------------------------------
    gamma: float = 1.0
    l1_penalty_coefficient: float = 0.1

    # --- "sfw" procedure limits (acronym not expanded in this file) ---
    sfw_max_iterations: int = 350
    sfw_patience: int = 10

    # --- feature-selector network -------------------------------------
    fs_model: str = "settransformer"
    fs_hidden_dim: int = 128
    fs_dropout: float = 0.1
    fs_n_heads: int = 4

    # --- weight decay -------------------------------------------------
    weight_decay_merlin: float = 1e-4
    weight_decay_morgana: float = 1e-4
    weight_decay_fs: Optional[float] = None  # None = unspecified; fallback (if any) lives elsewhere

    # --- reporting / analysis switches --------------------------------
    feature_distribution: bool = True
    compute_prec_and_ent: bool = False
    compute_avg_occ: bool = False
    feat_interp_ncb_s0: bool = True