from dataclasses import dataclass, field
from transformers import TrainingArguments
from typing import List


@dataclass
class ModelArguments:
    model_name: str = field(metadata={"help": "huggingface model name or path"})
    model_type: str = field(default=None, metadata={"help": "model type, typically includes in config file, but sometimes needs mannually add"})
    processor_name: str = field(default=None, metadata={"help": "processor_name, huggingface model name or path"})
    model_backbone: str = field(default=None, metadata={"help": "HF model type"})
    checkpoint_path: str = field(default=None, metadata={"help": "a local model path, could be a LoRA version"})
    pooling: str = field(default='last', metadata={"help": "pooling method for encoder"})
    normalize: bool = field(default=False, metadata={"help": "normalize query and passage representations"})
    temperature: float = field(default=0.02, metadata={"help": "temperature for softmax"})
    lora: bool = field(default=False, metadata={"help": "do parameter-efficient fine-tuning with lora"})
    lora_r: int = field(default=16, metadata={"help": "lora r"})
    lora_alpha: int = field(default=64, metadata={"help": "lora alpha"})
    lora_dropout: float = field(default=0.1, metadata={"help": "lora dropout"})
    lora_target_modules: str = field(default="qkv_proj,o_proj,gate_up_proj,down_proj,k_proj,q_proj,out_proj,v_proj", metadata={"help": "lora target modules"})
    num_crops: int = field(default=16, metadata={"help": "number of crops used in image encoder"})
    uigraph_use: bool = field(default=False, metadata={"help": "Enable ui graph for token selection"})
    uigraph_diff: int = field(default=1, metadata={"help": "Pixel difference used for constructing ui graph for token selection"})
    uigraph_rand: bool = field(default=False, metadata={"help": "Enable random graph construction for token selection"})
    uimask_ratio: float = field(default=0.5, metadata={"help": "Specify the percentage of patch tokens to skip per component for token selection"})
    uimask_rand: bool = field(default=False, metadata={"help": "Enable random token selection instead of uniform selection"})
    lm_skip_layer: str = field(default='[1,28,0]', metadata={"help": "Specify the layers of the language model to skip for token selection"})
    vis_skip_layer: str = field(default='[1,32,0]', metadata={"help": "Specify the layers of the vision model to skip for token selection"})
    #! new args
    init_lora_model: bool = field(default=False, metadata={"help": "initializing with lora model"})
    # distiller args:
    teacher_backbone: str = field(default=None, metadata={"help": "teacher model backbone"})
    teacher_model_name: str = field(default=None, metadata={"help": "teacher model name or path"})
    teacher_lora: bool = field(default=False, metadata={"help": "whether teacher is lora"})
    teacher_lora_r: int = field(default=16, metadata={"help": "teacher lora r"})
    teacher_lora_alpha: int = field(default=64, metadata={"help": "teacher lora alpha"})
    teacher_lora_dropout: float = field(default=0.1, metadata={"help": "teacher lora dropout"})
    teacher_lora_target_modules: str = field(default="qkv_proj,o_proj,gate_up_proj,down_proj,k_proj,q_proj,out_proj,v_proj", metadata={"help": "teacher lora target modules"})
    teacher_pooling: str = field(default='last', metadata={"help": "pooling method for teacher encoder"})
    teacher_normalize: bool = field(default=False, metadata={"help": "normalize query and passage representations for teacher"})
    projector_config_path: str = field(default=None, metadata={"help": "projector config path, if None, no projector will be used"})
    projector_path: str = field(default=None, metadata={"help": "projector model path, if None, no projector will be used"})
    projector_lr: float = field(default=1e-4, metadata={"help": "projector learning rate"})
    student_hidden_dim: int = field(default=896, metadata={"help": "student hidden dim"})
    teacher_hidden_dim: int = field(default=1536, metadata={"help": "teacher hidden dim"})
    load_pretrained_lora: bool = field(default=False, metadata={"help": "load pretrained lora model for student"})
    #! new args for span loss
    
    

@dataclass
class DataArguments:
    dataset_config: str = field(default=None, metadata={"help": "yaml file with dataset configuration"})
    dataset_name: str = field(default=None, metadata={"help": "huggingface dataset name"})
    subset_name: List[str] = field(default=None, metadata={"help": "Useful for datasets with subsets"})
    dataset_split: str = field(default='train', metadata={"help": "dataset split"})
    num_sample_per_subset: int = field(default=None, metadata={"help": "number of training samples per subset"})
    image_dir: str = field(default=None, metadata={"help": "Image directory path"})
    encode_output_path: str = field(default=None, metadata={"help": "encode output path"})
    max_len: int = field(default=None, metadata={"help": "The maximum total input sequence length after tokenization. Use with caution, since it may truncate text prompts due to large image lengths."},)
    embedding_type: str = field(default="", metadata={"help": "embedding type"})
    image_resolution: str = field(default=None, metadata={"help": "for models i.e. LLaVA-next and Qwen, resize images first, none means using original image resolution. This is only works when `--resize_use_processor false`."})
    resize_use_processor: bool = field(default=False, metadata={"help": "Resize visual inputs insides processor, e.g. Qwen2VLImageProcessor, instead of by our code."})
    resize_min_pixels: int = field(default=28*28*4, metadata={"help": "The min pixels of the image to resize the image. This is only works when `--resize_use_processor true`."})
    resize_max_pixels: int = field(default=28*28*1280, metadata={"help": "The max pixels of the image to resize the image. This is only works when `--resize_use_processor true`."})
    image_decay_factor: float = field(default=None, metadata={"help": "The image decay factor for resizing temporal images"})
    num_hardneg: int = field(default=0, metadata={"help": "hard negative number"})
    #! new args
    sdibn: bool = field(default=False, metadata={"help": "huggingface model name"})
    odibn: bool = field(default=False, metadata={"help": "huggingface model name"})
    rdibn: bool = field(default=False, metadata={"help": "huggingface model name"})
    tgt_prefix_mod: bool = field(default=False, metadata={"help": "Modify the pos_prefix"})
    chunk_size: int = field(default=32, metadata={"help": "Cluster sizes in metis. Only used in odibn"})
    #!new args 2
    eval_dataset_name: str = field(default=None, metadata={"help": "Useful for datasets with subsets"})
    eval_subset_name: List[str] = field(default=None, metadata={"help": "Useful for datasets with subsets"})
    eval_image_dir: str = field(default=None, metadata={"help": "Eval Image directory path"})
    pos_only: bool = field(default=False, metadata={"help": "Only use positives"})
    # new args distillation
    percent_data: float = field(default=1.0, metadata={"help": "percentage of data used for distillation training"})
    


@dataclass
class TrainingArguments(TrainingArguments):
    image_encoder_freeze: bool = field(default=False, metadata={"help": "huggingface model name"})
    output_dir: str = field(default=None, metadata={"help": "directory for saving trained models"})
    resume_from: str = field(default="none", metadata={"help": "`auto` will detect if any previous checkpoints should be resumed. or specify specific step of the checkpoint."})
    project_name: str = field(default=None, metadata={"help": "project name"})
    logging_steps: int = field(default=1, metadata={"help": "logging steps"})
    num_train_epochs: int = field(default=1, metadata={"help": "number of training epochs"})
    grad_cache: bool = field(default=False, metadata={"help": "Use gradient cache update"})
    gc_q_chunk_size: int = field(default=128, metadata={"help": "query side subset size. Should be power of 2"})
    gc_p_chunk_size: int = field(default=128, metadata={"help": "target side subset size. Should be power of 2"})
    interleave_stopping_strategy: str = field(default="all_exhausted", metadata={"help": "all_exhausted or first_exhausted"})
    interleave_batch_size: float = field(default=0, metadata={"help": "Specify mini-batch size to interleave data from multi-sources, 0/None means random sampling by examples, 1 means full batch."})
    #!new args
    gc_dynamic_limit: int = field(default=125, metadata={"help": "gc_chunk default limit - (128, 125) sized matrices works for Qwen2b. gc_dynamic_limit would be 125 and gc_p|q_chunk_size would be 128"})
    #!new kd loss weight
    kd_weight: float = field(default=0.01, metadata={"help": "weight of kd loss in total loss"})
    rkd_distance_weight: float = field(default=1.0, metadata={"help": "weight of distance loss in total kd loss"})
    rkd_angle_weight: float = field(default=2.0, metadata={"help": "weight of angle loss in total kd loss"})
    kd_loss_type: str = field(default="contrastive_rkd", metadata={"help": "type of kd loss, current only support RKD"})
    ds_config: str = field(default=None, metadata={"help": "DeepSpeed config json file path"})
    deepspeed_config: str = field(default=None, metadata={"help": "DeepSpeed config json file path"})
    # new args for span loss
    teacher_layer_mapping: List[int] = field(
        default_factory=list,
        metadata={"help": "List of teacher layers used for distillation; number of elements equals number of projectors"}
    )
    student_layer_mapping: List[int] = field(
        default_factory=list,
        metadata={"help": "List of student layers used for distillation; number of elements equals number of projectors"}
    )
    split_layer_mapping: List[int] = field(
        default_factory=list,
        metadata={"help": "List of split layers for student; number of elements equals number of projectors"}   
    )
    w_cross_modal_loss: float = field(default=1.0, metadata={"help": "weight for cross modal loss"})
@dataclass
class MTEBArguments:
    device: str = field(default="cuda", metadata={"help": "use cuda for single GPU inference, if multiple GPUs are available it will use DP automatically"})
    batch_size_per_device: int = field(default=16, metadata={"help": ""})
    max_length: int = field(default=512, metadata={"help": ""})
    eval_output_dir: str = field(default=None, metadata={"help": "directory for saving trained models"})
    task_types: List[str] = field(default=None, metadata={"help": ""})
    tasks: List[str] = field(default=None, metadata={"help": ""})
    prompt_family: List[str] = field(default=None, metadata={"help": ""})
