from dataclasses import dataclass
from typing import List

from util.hparams import HyperParams


@dataclass
class LORAHyperParams(HyperParams):
    """Hyperparameters for the LoRA method, declared as a dataclass.

    ``@dataclass`` generates ``__init__`` from the fields below in declaration
    order. Every field except ``batch_size`` is required; ``batch_size``
    defaults to 128. Reordering fields changes the positional signature of
    the generated constructor, so keep the order stable.

    NOTE(review): how instances are populated (for example, loaded from a
    config file) and whether ``HyperParams`` contributes fields of its own
    are not visible here. Confirm in ``util.hparams``.
    """

    # Method / optimization settings.
    layers: List[int]  # presumably indices of model layers to edit/adapt; confirm
    num_steps: int  # presumably number of optimization steps; confirm
    lr: float  # presumably the optimizer learning rate
    weight_decay: float  # presumably optimizer weight decay
    kl_factor: float  # presumably weight on a KL-divergence term; confirm in the trainer
    norm_constraint: float  # meaning not visible here (weight/update norm bound?); confirm

    # Module templates: strings naming model submodules.
    # NOTE(review): these look like format templates (e.g. with a "{}" for
    # the layer index); verify against the code that consumes them.
    rewrite_module_tmp: str
    layer_module_tmp: str
    mlp_module_tmp: str
    attn_module_tmp: str
    ln_f_module: str
    lm_head_module: str

    # LoRA adapter configuration.
    lora_type: str  # presumably selects the LoRA variant; valid values not shown here
    target_modules: List[str]  # presumably names of modules LoRA adapters attach to
    lora_dropout: float  # presumably dropout probability inside the adapters
    lora_alpha: int  # presumably the LoRA scaling numerator (alpha / rank)
    rank: int  # presumably the LoRA low-rank dimension r

    # Defaults.
    # Fields with defaults must come after all fields without defaults;
    # otherwise the dataclass decorator raises TypeError at class creation.
    batch_size: int = 128
