from typing import Dict, Optional, Sequence, List
from dataclasses import dataclass, field
import transformers

@dataclass
class ModelArguments:
    """Model-related options: base checkpoint, protein tower and projector.

    Every field has a default, so the class can be instantiated without
    arguments. Field order is part of the positional ``__init__`` signature
    and must not be changed.
    """

    # Base checkpoint name or path, and a free-form version tag.
    model_name_or_path: Optional[str] = "facebook/opt-125m"
    version: Optional[str] = "v0"

    # Flags selecting which parts of the model are frozen or tuned.
    freeze_backbone: bool = False
    tune_mm_mlp_adapter: bool = False

    # NOTE(review): both `protein_tower` and `mm_protein_tower` (below) exist;
    # confirm against the callers which one is actually consumed.
    protein_tower: Optional[str] = None
    mm_protein_select_layer: Optional[int] = -1  # default to the last layer
    pretrain_mm_mlp_adapter: Optional[str] = None

    # Declared exactly once: the original listed this field a second time
    # further down with the same type and default, which was redundant.
    mm_projector_type: Optional[str] = "linear"

    mm_use_prot_start_end: bool = False
    mm_use_prot_patch_token: bool = True
    mm_protein_select_feature: Optional[str] = "patch"
    mm_protein_tower: Optional[str] = None
    use_mm_proj: bool = True
    protein_max_len: int = 256
    tune_norm_layer: bool = False

    residual_dropout: float = 0.2


@dataclass
class DataArguments:
    """Dataset locations and data-loading switches.

    Every field has a default, so the class can be instantiated without
    arguments. Field order is part of the positional ``__init__`` signature
    and must not be changed.
    """

    # data_path: str = field(default=None,
    #                        metadata={"help": "Path to the training data."})
    lazy_preprocess: bool = False
    is_multimodal: bool = False
    go_term_graph: str = field(
        default='datasets/GOA_Human/go.obo',
        # Help text previously said "*.odo", which did not match the default.
        metadata={"help": "Path to the go term graph file, the *.obo file."},
    )
    protein_pkl: str = field(
        default='datasets/GOA_Human/train_data_fold_0.pkl',
        metadata={"help": "Path to the protein file, the *.pkl file."},
    )
    max_protein_length: int = field(default=256)

    eval_protein_pkl: str = field(
        default='datasets/GOA_Human/validation_data_fold_0.pkl',
        metadata={"help": "Path to the validation protein file, the *.pkl file."},
    )

    # No type annotation, so @dataclass treats this as a plain class attribute
    # rather than a field: it is not an __init__ parameter and does not appear
    # in dataclasses.fields().
    # NOTE(review): annotate it if it is meant to be configurable.
    max_labels = 5

    training_csv_file: str = field(
        default='datasets/split100.csv',
        metadata={"help": "Path to the ec number csv file, contains the ec number-sequence pairs."},
    )
    eval_csv_file: str = field(
        default='datasets/new.csv',
        metadata={"help": "Path to the evaluation ec number csv file, contains the ec number-sequence pairs."},
    )
    # Comma-separated list of csv paths stored in a single string.
    test_csv_files: str = field(
        default="datasets/halogenase.csv,datasets/multi.csv,datasets/price.csv,datasets/new.csv"
    )
    enzclass_file: str = field(default="datasets/enzclass_dict.pkl")
    with_comments: bool = False
    comments_file: str = field(default='datasets/ec_description.json')
    query_template: str = field(default='datasets/query_template.json')

    retrieval_mmseq: bool = field(default=False)
    retrieval_smiles: bool = field(default=False)
    ecnumber_rheadid_file: str = field(
        default='datasets/ecnumber_rheaid_mapping.json'
    )

    smiles_path: str = field(default='datasets/SMILES')
    # image_folder: Optional[str] = field(default=None)
    # image_aspect_ratio: str = 'square'


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Training options layered on top of ``transformers.TrainingArguments``.

    Adds quantization, LoRA and projector-specific settings to whatever the
    parent class provides. Every field declared here has a default.
    """

    cache_dir: Optional[str] = None
    optim: str = "adamw_torch"
    remove_unused_columns: bool = False
    freeze_mm_mlp_adapter: bool = False
    mpt_attn_impl: Optional[str] = "triton"
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )

    # Quantization settings.
    double_quant: bool = field(
        default=True,
        metadata={
            "help": "Compress the quantization statistics through double quantization."
        },
    )
    quant_type: str = field(
        default="nf4",
        metadata={
            "help": "Quantization data type to use. Should be one of `fp4` or `nf4`."
        },
    )
    bits: int = field(default=16, metadata={"help": "How many bits to use."})

    # LoRA settings.
    lora_enable: bool = field(default=False)
    lora_r: int = field(default=64)
    lora_alpha: int = field(default=16)
    lora_dropout: float = field(default=0.05)
    lora_weight_path: str = field(default="")
    lora_bias: str = field(default="none")

    mm_projector_lr: Optional[float] = field(default=None)
    group_by_modality_length: bool = False
    output_dir: Optional[str] = "model_ckpt"

    is_test: bool = False