from dataclasses import dataclass

@dataclass
class CausalConfig:
    """Configuration values for the causal model and its training run.

    Groups the backbone location, graph-related sizes, loss weights,
    training settings and LoRA layer selection in one place.

    Fields without a default are required and are declared first; fields
    with a default follow. ``dataclasses`` raises ``TypeError`` at class
    creation when a non-default field follows a defaulted one, so the two
    groups must not be interleaved. Construct instances with keyword
    arguments so call sites do not depend on declaration order.
    """

    # ------------------------------------------------------------------
    # Required fields (no default)
    # ------------------------------------------------------------------

    # Model-related configurations
    lambda_text: float  # Weight for text loss
    lambda_consistency: float  # Weight for consistency loss

    # Training configurations
    # NOTE(review): the names mirror Hugging Face TrainingArguments options;
    # presumably they are forwarded there -- confirm against the caller.
    per_device_train_batch_size: int
    per_device_eval_batch_size: int
    gradient_accumulation_steps: int
    lr: float
    weight_decay: float
    # NOTE: "wramup" is a misspelling of "warmup". The name is kept so that
    # existing keyword arguments and attribute reads keep working; use the
    # ``warmup_ratio`` property for a correctly spelled read-only alias.
    wramup_ratio: float

    evaluation_strategy: str
    save_steps: int
    eval_steps: int

    save_strategy: str

    deepspeed: str
    report_to: str

    # LoRA configurations
    # lora_rank_vision: int
    # lora_alpha_vision: int
    # lora_dropout_vision: float

    # Backbone LoRA configurations: only train first 0..x layers
    decoder_train_lora_first_upto_layer: int
    # Backbone LoRA configurations: only train last x..n layers
    decoder_train_lora_last_n_layers: int

    # Vision encoder LoRA: only train last n layers
    siglip_lora_enable: bool
    siglip_lora_last_n_layers: int
    siglip_lora_include_img_proj: bool

    # ------------------------------------------------------------------
    # Optional fields (with defaults)
    # ------------------------------------------------------------------

    model_name: str = "microsoft/Phi-4-multimodal-instruct"  # Backbone model name
    local_dir: str = "third_party/phi_4"  # Local directory of the trained model

    # Graph-related configurations
    d_max: int = 10  # Max dimension for A matrix
    rank_r: int = 8  # Rank for low-rank adaptation
    # Number of graph embedding tokens.
    # NOTE(review): the original note suggests choosing this as a multiple
    # of d_max -- confirm the intended relationship.
    num_graph_tokens: int = 20
    num_train_epochs: int = 4

    # Training configurations
    bf16: bool = True
    gradient_checkpointing: bool = True

    logging_steps: int = 20

    max_length: int = 3072

    @property
    def warmup_ratio(self) -> float:
        """Return ``wramup_ratio`` under its correctly spelled name."""
        return self.wramup_ratio

    
    
    
    
    
    
    

    
