from typing import Optional
from dataclasses import dataclass, field
import transformers
import torch

@dataclass
class OtherArguments:
    """Miscellaneous options: model checkpoint paths, image/question/answer
    locations and run settings.

    Every field carries a default, so the class can be built with no arguments.
    """

    # Model checkpoint to load and an optional base model.
    model_path: Optional[str] = "./checkpoints/llava-llama-2-7b-chat-lightning-preview"
    model_base: Optional[str] = None

    # Image folder and question / answer locations.
    eval_image_folder: Optional[str] = "../POPE/data/minival2014/minival2014"
    question_path: Optional[str] = "../POPE/llava_qa/question"
    question_file: Optional[str] = "I1_sub240_control.json"
    answer_path: Optional[str] = "../POPE/llava_qa/answer"
    answers_file: Optional[str] = None

    # Run settings.
    conv_mode: Optional[str] = "llava_llama_2"
    num_chunks: Optional[int] = 1
    chunk_idx: Optional[int] = 0
    cfg: Optional[float] = None
    cfg_seed: Optional[int] = 42
    batch_size: Optional[int] = 16
    answer_prompter: Optional[bool] = False

    # Optional projector checkpoint paths; unset by default.
    tune_manual_mm_projector_path: Optional[str] = None
    pretrain_mm_projector_path: Optional[str] = None




@dataclass
class ModelArguments:
    """Model-related options.

    ``model_name_or_path`` has no default and must be supplied by the caller;
    all remaining fields fall back to the defaults below.
    """

    model_name_or_path: Optional[str]
    freeze_backbone: bool = False
    tune_mm_mlp_adapter: bool = False
    vision_tower: Optional[str] = None
    # Defaults to the last layer.
    mm_vision_select_layer: Optional[int] = -1
    pretrain_mm_mlp_adapter: Optional[str] = None
    mm_use_im_start_end: bool = False
    mm_use_im_patch_token: bool = True
    mm_vision_select_feature: Optional[str] = "patch"


@dataclass
class VisionModuleArguments_with_vm_prefix:
    """Vision-module options whose field names all carry a ``vm_`` prefix.

    Every field has a default, so the class can be built with no arguments.
    Each field is declared exactly once; field order is part of the positional
    constructor signature and must not be changed.
    """

    vm_vision_tower: Optional[str] = field(default=None)
    vm_mm_vision_select_layer: Optional[int] = field(default=-1)   # default to the last layer
    vm_pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
    vm_tune_mm_mlp_adapter: bool = field(default=False)
    vm_mm_use_im_start_end: bool = field(default=False)
    vm_mm_use_im_patch_token: bool = field(default=True)
    vm_mm_vision_select_feature: Optional[str] = field(default="patch")


# For compatibility with ModelArguments when used in scripts: the same arguments
# as the vm_-prefixed class above, with the vm_ prefix removed from every name.
@dataclass
class VisionModuleArguments:
    """Vision-module options using un-prefixed field names.

    Every field has a default, so the class can be built with no arguments.
    Each field is declared exactly once; field order is part of the positional
    constructor signature and must not be changed.
    """

    vision_tower: Optional[str] = field(default=None)
    mm_vision_select_layer: Optional[int] = field(default=-1)   # default to the last layer
    pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
    tune_mm_mlp_adapter: bool = field(default=False)
    mm_use_im_start_end: bool = field(default=False)
    mm_use_im_patch_token: bool = field(default=True)
    mm_vision_select_feature: Optional[str] = field(default="patch")


@dataclass
class DataArguments:
    """Options describing the training data and how images are handled.

    Every field has a default, so the class can be built with no arguments.
    """

    # NOTE(review): annotated as ``str`` but the default is ``None``.
    data_path: str = field(
        default=None,
        metadata={"help": "Path to the training data."},
    )
    lazy_preprocess: bool = field(default=False)
    is_multimodal: bool = field(default=False)
    image_folder: Optional[str] = None
    image_aspect_ratio: str = field(default='square')
    image_grid_pinpoints: Optional[str] = None
    version: Optional[str] = "llava-llama-2"



@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """``transformers.TrainingArguments`` extended with extra options.

    Adds cache, adapter-freezing, sequence-length, quantization and LoRA
    settings; ``optim`` and ``remove_unused_columns`` are re-declared here with
    the defaults shown below.
    """

    cache_dir: Optional[str] = None
    optim: str = "adamw_torch"
    remove_unused_columns: bool = False
    freeze_mm_mlp_adapter: bool = False
    mpt_attn_impl: Optional[str] = "triton"

    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )

    # Quantization settings.
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."},
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."},
    )
    bits: int = field(default=16, metadata={"help": "How many bits to use."})

    # LoRA settings.
    lora_enable: bool = field(default=False)
    lora_r: int = field(default=64)
    lora_alpha: int = field(default=16)
    lora_dropout: float = field(default=0.05)
    lora_weight_path: str = field(default="")
    lora_bias: str = field(default="none")



def update_bnb_model_from_pretrained_args(bnb_model_from_pretrained_args, training_args):
    """Add quantized-loading keyword arguments when ``bits`` is 4 or 8.

    Args:
        bnb_model_from_pretrained_args: Mapping of keyword arguments; it is
            updated in place when quantization applies.
        training_args: Object exposing ``bits``, ``device``, ``double_quant``
            and ``quant_type``, plus whatever
            ``compute_dtype_from_training_args`` reads from it.

    Returns:
        The same mapping object that was passed in. It is left untouched when
        ``training_args.bits`` is neither 4 nor 8.
    """
    # Any other bit width: nothing to add.
    if training_args.bits not in (4, 8):
        return bnb_model_from_pretrained_args

    # Imported lazily so it is only needed on the quantized path.
    from transformers import BitsAndBytesConfig

    device_map = {"": training_args.device}
    use_4bit = training_args.bits == 4
    use_8bit = training_args.bits == 8

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=use_4bit,
        load_in_8bit=use_8bit,
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
        bnb_4bit_compute_dtype=compute_dtype_from_training_args(training_args),
        bnb_4bit_use_double_quant=training_args.double_quant,
        bnb_4bit_quant_type=training_args.quant_type,  # {'fp4', 'nf4'}
    )

    bnb_model_from_pretrained_args.update({
        "device_map": device_map,
        "load_in_4bit": use_4bit,
        "load_in_8bit": use_8bit,
        "quantization_config": quantization_config,
    })
    return bnb_model_from_pretrained_args


def compute_dtype_from_training_args(training_args):
    """Return the torch dtype selected by the ``fp16`` / ``bf16`` flags.

    ``fp16`` is checked first and wins over ``bf16``; when neither flag is
    truthy the result is ``torch.float32``.
    """
    if training_args.fp16:
        return torch.float16
    if training_args.bf16:
        return torch.bfloat16
    return torch.float32

