from dataclasses import dataclass, field
from typing import Optional, Union, List, Literal
from peft.config import PeftConfig
from peft.utils import PeftType

@dataclass
class GroupLoraConfig(PeftConfig):
    """Configuration for the GroupLoRA PEFT adapter.

    Holds the hyper-parameters for the GroupLoRA method. In ``__post_init__``
    it tags itself with ``PeftType.GROUPLORA`` so the peft machinery can
    dispatch on it.

    NOTE(review): ``PeftType.GROUPLORA`` is not part of upstream peft's enum;
    this assumes the project registers it elsewhere — confirm.
    """

    # Module name(s) that receive the adapter; None leaves the choice to the
    # consuming code (presumably a default mapping — verify against caller).
    target_modules: Optional[Union[List[str], str]] = field(
        default=None, metadata={"help": "List of module names to replace with LoRA"})
    r: int = field(default=32, metadata={"help": "GroupLora attention dimension (rank)"})
    alpha: int = field(default=32, metadata={"help": "GroupLora alpha"})
    dropout: float = field(default=0.05, metadata={"help": "GroupLora dropout"})

    fan_in_fan_out: bool = field(default=False, metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"})
    bias: Literal["none", "all", "lora_only"] = field(
        default="none", metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"}
    )
    # Group size used by the GroupLoRA layers; its exact meaning is defined by
    # the consuming model code, which is not visible here.
    group_size: int = field(default=4, metadata={"help": "Group size used by GroupLora"})
    num_layers: int = field(default=12, metadata={"help": "Number of transformer layers"})

    # typing.List (not builtin list[...]) for consistency with target_modules
    # and compatibility with Python < 3.9.
    modules_to_save: Optional[List[str]] = field(
        default=None,
        metadata={"help": "List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. "
                    "For example, in Sequence Classification or Token Classification tasks, "
                    "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved."
        })

    def __post_init__(self):
        """Run the base-class post-init hook (if any) and tag the PEFT type."""
        # Some peft versions define PeftConfig.__post_init__ (e.g. task_type
        # validation); older ones do not. Chain to it only when it exists so
        # overriding here does not silently skip base-class checks.
        parent_post_init = getattr(super(), "__post_init__", None)
        if parent_post_init is not None:
            parent_post_init()
        self.peft_type = PeftType.GROUPLORA