import json
from dataclasses import asdict, dataclass, field
from typing import Optional

@dataclass
class AsrLLMConfig:
    """Container of configuration values that can be dumped to JSON.

    Every field has a default, so ``AsrLLMConfig()`` is a valid instance.
    The ``*_path`` fields default to ``None``, i.e. "not set".
    """

    name: str = "asr-llm"
    speech_encoder_type: str = "whisper"
    speech_encoder_path: Optional[str] = None
    llm_path: Optional[str] = None
    use_flash_attn: bool = True
    stage: int = 1
    pretrained_stage1_model_path: Optional[str] = None
    # NOTE(review): the next five fields look like LoRA/PEFT settings —
    # confirm against the code that consumes this config.
    lora_rank: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    # default_factory gives each instance its own empty list.
    target_modules: list = field(default_factory=list)
    task_type: str = "CAUSAL_LM"
    encoder_projector_ds_rate: int = 8
    subsampling_factor: int = 2

    def to_dict(self) -> dict:
        """Return the declared fields as a new, independent dictionary.

        The result is a deep copy (via ``dataclasses.asdict``): mutating it,
        or the ``target_modules`` list inside it, does not affect this
        config. Keys appear in field-declaration order.
        """
        return asdict(self)

    def save_to_json(self, json_path: str) -> None:
        """Write the config to ``json_path`` as JSON indented by 4 spaces.

        An existing file at that path is overwritten. Errors from ``open``
        (e.g. ``OSError``) propagate to the caller.
        """
        # Explicit encoding so the output does not depend on the platform locale.
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=4)