from dataclasses import dataclass
from typing import List, Union
from ...util.hparams import HyperParams
import yaml

# Hyperparameter values, unpacked as keyword arguments into ``WiseConfig``
# further down, so every key must match a ``WiseConfig`` field name and every
# ``WiseConfig`` field must have a key here.
config_dict = {
    'mask_ratio': 0.2,
    'edit_lr': 1.,
    'n_iter': 70,
    'norm_constraint': 1.,
    'weight_decay': 1e-5,
    'objective_optimization': 'only_label',
    'act_ratio': 0.88,
    'merge_freq': 1000,
    'save_freq': 500,
    'merge_alg': 'ties',
    'densities': 0.53,
    'weights': 1.0,
    'retrieve': True,
    'inner_params': ['model.layers[27].mlp.down_proj'],
    # WiseConfig declares ``memory_size`` as a required field; without this
    # key ``WiseConfig(**config_dict)`` raises TypeError at import time.
    # NOTE(review): None is a placeholder -- TODO confirm the intended value.
    'memory_size': None,
}

@dataclass
class WiseConfig:
    """Hyperparameter container, presumably for the WISE editing method.

    Every field is required (none has a default), so building an instance
    from a dict with ``WiseConfig(**d)`` raises ``TypeError`` if any field
    name is missing from ``d``. Field order is also the positional order of
    the generated ``__init__``.
    """
    # Experiments

    edit_lr: float
    weight_decay: float
    n_iter: int  # presumably an iteration count -- confirm against the trainer
    # Method
    objective_optimization: str
    mask_ratio: float
    # NOTE(review): not read anywhere in the visible code; confirm meaning.
    memory_size: int
    act_ratio: float
    merge_freq: int
    retrieve: bool
    save_freq: Union[int, None]  # None is an accepted value
    merge_alg: str
    norm_constraint: float
    # Module templates
    # Module paths as strings, e.g. 'model.layers[27].mlp.down_proj'.
    inner_params: List[str]
    weights: Union[float, None]  # None is an accepted value
    densities: Union[float, None]  # None is an accepted value


# Module-level config instance: each key of ``config_dict`` is forwarded to
# ``WiseConfig`` as a keyword argument.
wise_config = WiseConfig(
    **config_dict,
)