from typing import Dict, Optional, Sequence, List, Union
import transformers
from dataclasses import dataclass, field


@dataclass
class ModelArguments:
    """Model-loading options plus the R-LoRA initialization settings.

    Every attribute has a default, so the class can be instantiated with no
    arguments. Fields declared through ``field(...)`` carry a ``help`` string
    in their metadata; the two LoRA init-method fields additionally carry a
    ``choices`` list naming the accepted method identifiers.
    """

    # Alternative local checkpoint kept for reference:
    # "/vsphhome/xwx/med-MLLM/baseline_merging/task_vectors/checkpoints/qwen2_scal0.5"
    vlm_model_path: str = "llava-hf/llava-onevision-qwen2-7b-ov-hf"
    # Defaults to None, so the annotation is Optional rather than plain str.
    peft_model_path: Optional[str] = None
    # NOTE(review): -1 presumably means "no quantization" -- confirm with the
    # code that consumes this value.
    quant_bit: int = -1
    quant_type: str = "nf4"
    # Defaults to None, so the annotation is Optional rather than plain str.
    checkpoint: Optional[str] = None
    hook_model: bool = False

    ##### for r-lora #####
    reinit: bool = field(
        default=True,
        metadata={"help": "Whether to reinit the lora layer"},
    )
    rank_stablization: bool = field(
        default=True,
        metadata={"help": "Whether to use rank stablization for scaling"},
    )
    lora_A_init: Optional[str] = field(
        default="unit",
        metadata={
            "help": "Initialization method for lora_A",
            "choices": [
                "gaussian",
                "kaiming",
                "fan_out_kaiming",
                "xavier",
                "zeros",
                "unit",
                "orthogonal",
            ],
        },
    )
    lora_B_init: Optional[str] = field(
        default="unit",
        metadata={
            "help": "Initialization method for lora_B",
            "choices": [
                "gaussian",
                "kaiming",
                "fan_out_kaiming",
                "xavier",
                "zeros",
                "unit",
                "orthogonal",
            ],
        },
    )
    init_scale: Optional[str] = field(
        default="stable",
        metadata={"help": "Scaling method for initialization"},
    )
    stable_gamma: Optional[int] = field(
        default=64,
        metadata={"help": "Gamma value for stable scaling"},
    )
    init_bs: Optional[int] = field(
        default=5,
        metadata={"help": "Batch size for gradient initialization"},
    )
    ##### for r-lora ######


@dataclass
class DataArguments:
    """Data-related arguments: a single list-valued ``data_dir`` field.

    NOTE(review): whether the entries are dataset names or filesystem paths
    is not visible here -- confirm against the data-loading code.
    """

    # A factory is used so every instance receives its own fresh list.
    data_dir: List[str] = field(
        default_factory=lambda: ["slake_vqa"],
    )


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Project training configuration layered on ``transformers.TrainingArguments``.

    Declares the project's PEFT/LoRA-related fields and assigns project
    defaults to the remaining fields listed below. NOTE(review): most of the
    latter (e.g. ``output_dir``, ``bf16``, ``deepspeed``) appear to share
    names with options of the parent class and would then act as default
    overrides -- confirm against the installed ``transformers`` version.
    """

    output_dir: str = "v2/output"

    # --- PEFT / adapter selection ---
    use_peft: bool = True
    # Values listed by the author:
    # lora, ada-lora, vb-lora, vera, prefix-tuning, prompt-tuning, ia3
    peft_type: str = "lora"
    lora_rank: int = 8
    p_tuning_token: int = 256
    my_lora: bool = False
    neighbor_gap: int = 1
    r_scaling: int = 6
    householder_dim: int = 4
    rotation_angle: float = 1
    lora_target_modules: Union[List[str], str] = field(
        # "q_proj", "v_proj" and "k_proj" are commented out in the original
        # list, leaving only "down_proj" as the default target.
        default_factory=lambda: ["down_proj"],
    )
    requires_grad_list: List[str] = field(
        default_factory=lambda: ["multi_modal_projector"],
    )

    # --- Run control and logging ---
    gradient_checkpointing: bool = False
    logging_steps: int = 16
    logging_strategy: str = "steps"
    exp_name: str = "lora"
    do_train: bool = True
    do_eval: bool = True
    num_train_epochs: float = 3.0

    # --- Batch sizing ---
    # Alternative setting kept for reference: per-device train/eval batch
    # size 6 with gradient_accumulation_steps 4.
    per_device_eval_batch_size: int = 1
    per_device_train_batch_size: int = 1
    gradient_accumulation_steps: int = 8
    eval_accumulation_steps: int = 1

    # --- Schedule, checkpointing, precision ---
    warmup_steps: int = 100
    save_total_limit: int = 1
    save_strategy: str = "epoch"
    # The fp16 alternative (``fp16: bool = True``) is left disabled.
    bf16: bool = True
    seed: int = 42
    deepspeed: str = "v2/config/zero2.json"
    remove_unused_columns: bool = False
    save_only_model: bool = True
