from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence, List

import transformers

@dataclass
class ModelArguments:
    """Model-related options, declared as dataclass fields.

    Every field carries a default, so ``ModelArguments()`` can be built with
    no arguments. The declaration order below is also the positional order
    of the generated ``__init__``, so it must not be shuffled.
    """

    # Checkpoint identifier/path and a version tag.
    model_name_or_path: Optional[str] = "facebook/opt-125m"
    version: Optional[str] = "v0"
    freeze_backbone: bool = False

    # Vision tower and ``mm_*`` (multimodal) settings.
    vision_tower: Optional[str] = None
    mm_vision_select_layer: Optional[int] = -1  # -1 defaults to the last layer
    pretrain_mm_mlp_adapter: Optional[str] = None
    mm_projector_type: Optional[str] = "linear"
    mm_use_im_start_end: bool = False
    mm_use_im_patch_token: bool = True
    mm_patch_merge_type: Optional[str] = "flat"
    mm_vision_select_feature: Optional[str] = "patch"

    # Contrastive-objective switches and document-model initialisation flag.
    # NOTE(review): exact semantics are defined by the consumers of these
    # flags, which are not visible in this file.
    inter_contrastive: bool = False
    intra_contrastive: bool = False
    doc_model_init: bool = True

@dataclass
class DataArguments:
    """Dataset and input-construction options, declared as dataclass fields.

    Every field carries a default, so ``DataArguments()`` can be built with
    no arguments. Path-like fields default to ``None`` (meaning "not
    provided") and are therefore annotated ``Optional[str]``. The
    declaration order is also the positional order of the generated
    ``__init__`` and must be preserved.
    """

    # --- data locations (``None`` when not supplied) ---
    train_query_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the training data."},
    )
    eval_query_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the validation data."},
    )
    document_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the document KB data."},
    )
    dataset_id_to_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the mapping function that maps query image id to path."},
    )
    # NOTE(review): the help text says "url to path" while the field name
    # suggests a url -> id mapping; confirm which one is intended.
    image_url_to_id_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the mapping function that maps document image url to path."},
    )

    # --- which modalities are used for queries / documents ---
    query_use_image: bool = field(
        default=True,
        metadata={"help": "Use image as query input"},
    )
    doc_use_image: bool = field(
        default=True,
        metadata={"help": "Use images for interleaved documents"},
    )
    doc_use_table: bool = field(
        default=True,
        metadata={"help": "Use tables for interleaved documents"},
    )
    mixed_query_modality: bool = field(
        default=False,
        metadata={"help": "Use both text-only query and multimodal query"},
    )
    is_multimodal: bool = False
    image_folder: Optional[str] = field(default=None)
    image_aspect_ratio: str = 'square'

    # --- input construction / experiment switches ---
    concat_sec: bool = field(
        default=False,
        metadata={"help": "For concatenating query with each section"},
    )
    debugging: bool = field(
        default=False,
        metadata={"help": "Debugging mode to take only 50 data"},
    )
    train_subset: bool = field(
        default=False,
        metadata={"help": "Subset experiment that uses only N% of the total data"},
    )
    train_subset_ratio: float = field(
        default=1.0,
        metadata={"help": "Subset experiment subset ratio"},
    )
    dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "For loading datasets"},
    )
    single_img: bool = field(
        default=False,
        metadata={"help": "Only the first image can be used in the document"},
    )
    only_summary: bool = field(
        default=False,
        metadata={"help": "Only the first section can be referred"},
    )
    subset_sec_num: int = field(
        default=4,
        metadata={"help": "The number of sections in a subset document."},
    )
    only_entity: bool = field(
        default=False,
        metadata={"help": "Only the title can be referred"},
    )
    passage_result_path: Optional[str] = field(
        default=None,
        metadata={"help": "Loading hard negatives"},
    )

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """``transformers.TrainingArguments`` extended with project-specific fields.

    All fields declared here have defaults. ``optim`` and
    ``remove_unused_columns`` presumably re-declare fields of the base class
    with different defaults (verify against the installed ``transformers``
    version).
    """

    # --- general ---
    cache_dir: Optional[str] = None
    optim: str = "adamw_torch"
    remove_unused_columns: bool = False
    mpt_attn_impl: Optional[str] = "triton"
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )

    # --- quantization ---
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."},
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."},
    )
    bits: int = field(
        default=16,
        metadata={"help": "How many bits to use."},
    )

    # --- LoRA ---
    lora_enable: bool = field(default=False)
    lora_r: int = field(default=64)
    lora_alpha: int = field(default=16)
    lora_dropout: float = field(default=0.05)
    lora_weight_path: str = field(default="")
    lora_bias: str = field(default="none")

    # Separate learning rate for the multimodal projector; ``None`` by default.
    # NOTE(review): how ``None`` is interpreted is decided by the trainer code,
    # which is not visible in this file.
    mm_projector_lr: Optional[float] = field(default=None)