import os
import sys
import torch

from sacred import Experiment

# Absolute path of the parent of the directory that contains this file.
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Appended before the project import below — presumably so that `mllmsd` resolves.
sys.path.append(root_dir)

from mllmsd.utils.util import map_name_task, get_short_name

# Sacred experiment object; every config/capture function in this file registers on it.
ex = Experiment("METER", save_git_info=False)

@ex.config
def config():
    """Base config scope: default values for the experiment's config entries."""
    # ---------------------------------------------------------------- models
    # Other ids noted for drf: "lmms-lab/llava-onevision-qwen2-0.5b-ov", "InternVL2-2B".
    drf = "llava-hf/llava-1.5-7b-hf"  # drf model
    # Other ids noted for tgt: "llava-hf/llava-1.5-13b-hf", "InternVL2-2B".
    tgt = "llava-hf/llava-1.5-7b-hf"  # target model
    # Alternative pair kept for reference:
    #   drf = "llava-hf/llava-interleave-qwen-0.5b-hf"
    #   tgt = "llava-hf/llava-interleave-qwen-7b-hf"

    # Other captioning ids noted: "Salesforce/blip2-opt-2.7b-coco",
    # "Salesforce/blip-image-captioning-base",
    # "ljnlonoljpiljm/florence-2-large-llava-recap-cc3m".
    captioning_model = "microsoft/Florence-2-large-ft"
    caption_type = "<MORE_DETAILED_CAPTION>"  # also noted: "<DETAILED_CAPTION>", ""

    drf_dtype = "fp16"
    tgt_dtype = "fp16"
    captioning_model_dtype = "fp16"
    caption_prefix = "image: "
    assistant_dtype = "fp32"
    # Values noted: "multimodal", "text-only", "caption", "tokenized-image",
    # "special-token", "image-pool".
    drafting = "multimodal"
    image_top_k_attention = 0  # llama - 576, qwen - 729
    # True only when `drafting` is one of the four strings listed here.
    is_drf_text_only = drafting in (
        "text-only",
        "special-token",
        "caption",
        "tokenized-image",
    )
    is_tgt_text_only = False
    is_drf_from_mllm = True
    drf_aux_tokenizer = None  # for base lm

    # -------------------------------------------------------------- drafting
    target_dim_image_pooling = 1  # which spatial dimension the pooled image would be
    image_pool_type = "avg2d"
    output_image_attentions = False
    logging_top_k = 5

    # Model ids listed in the original notes:
    #   "JackFram/llama-68m", "JackFram/llama-160m",
    #   "llava-hf/llava-1.5-7b-hf", "llava-hf/llava-1.5-13b-hf",
    #   "llava-hf/llava-interleave-qwen-0.5b-hf", "llava-hf/llava-interleave-qwen-7b-hf",
    #   "InternVL2-2B", "lmms-lab/llava-onevision-qwen2-0.5b-ov"

    # ------------------------------------------------------ draft generation
    max_prompt_length = 2048  # len(x)
    max_target_length = 128
    min_target_length = None
    assistant = None
    temperature = 0.0  # temperature for sampling during drf_model.generate
    cascade_rule = "mm-weight"  # values noted: "confidence", "mm-weight"
    mm_weight_policy = None
    mm_weight_k = None

    # -------------------------------------------------------------- decoding
    decoding = "sd"  # values noted: "sd", "ard"
    max_chunk_length = 5
    solution_type = "sol1"

    # ------------------------------------------------------------ experiment
    batch_size = 1
    seed = 2024
    debug = True
    test_only = True
    do_print = False
    save_steps = 2500
    is_time_factorized = False
    metric = [
        "sequences",
        "num_prompt_tokens",
        "num_accepted_tokens",
        "num_prefill_tokens_drf",
        "num_prefill_tokens_tgt",
        "ids_accepted_tokens",
        "ids_first_rejected",
        "tokens_first_rejected",
        "time_total",
        "time_prefill_drf",
        "time_prefill_tgt",
        "time_generate_drf",
        "time_verify_tgt",
        "time_prompt_process",
        "tokens_accepted_tokens_topk",
        "value_probability_accepted_topk",
        "tokens_rejected_tokens_topk",
        "value_probability_rejected_topk",
        # Original note for the next entry: num_samples x num chunks x num accepted
        # x num layers x num heads x query len (=1) x key len (num image tokens)
        "value_image_attention_drf_accepted",
        "value_image_attention_drf_first_rejected",
        "ids_image_attention_drf_accepted",
        "value_probability_ratio_accepted",
        "value_probability_ratio_first_rejected",
        "value_probability_accepted_drf",
        "value_probability_accepted_tgt",
        "value_probability_first_rejected_drf",
        "value_probability_first_rejected_tgt",
    ]

    # --------------------------------------------------------------- dataset
    # Dataset names noted in the original:
    #   "LLaVA-Instruct-150K", "COCO2014", "ScienceQA",
    #   "VibeEval", "DC100_EN", "LLaVA-Bench-Wilder",
    #   "Spot-the-Diff", "Birds-to-Words", "CLEVR-Change", "HQ-Edit", "MagicBrush",
    #   "IEdit", "AESOP", "FlintstonesSV", "PororoSV", "VIST", "WebQA"
    dataset = "LLaVA-Instruct-150K"
    eval_models = [None]
    eval_datasets = [None]
    eval_is_drf_text_only = [None]
    eval_drafting = [None]
    eval_max_chunk_length = [None]
    eval_target_dim_image_pooling = [None]
    eval_captioning_model = [None]
    eval_caption_type = [None]
    eval_image_top_k_attention = [None]
    eval_mm_weight_policy = [None]
    eval_cascade_rule = [None]
    eval_mm_weight_k = [None]
    exp_title = ""
    empty_cache = None

    tiny_data = False  # use small data for debugging
    reduce_data = None  # reduce data for efficiency (llava test set: 15772)

    # ------------------------------------------------------------ classifier
    train_classifier = False
    load_classifier = False
    # NOTE(review): the key is spelled `classifer`; left as-is because the name is the config key.
    classifer = "route"  # values noted: "route", "binary"
    classifier_top_k = None
    classifier_batch_size = 32
    classifier_lr = 1e-5
    classifier_epochs = 100
    classifier_arch = "small"  # values noted: "mini", "small"
    classifier_vocab = "top-k"  # values noted: "full", "top-k"
    classifier_ckpt_path = None

    # --------------------------------------------------------------- logging
    wandb_project_name = "MLLMSD"  # wandb project name

    # ----------------------------------------------------------------- paths
    root = "/XXXX-5/home-XXXX-3"  # root path
    input_datasets_dir = f"{root}/data/MSD/datasets/{dataset}"
    ckpt_dir = None  # load checkpoint
    npy_save_dir = ""
    # npy_save_dir = f"{root}/data/MSD/npy/{ckpt_save}"

@ex.named_config
def TrainClassifier():
    """Named config: enable `train_classifier` together with classifier, model and cascade settings."""
    # Classifier settings.
    train_classifier = True
    classifier_top_k = 5
    classifier_batch_size = 128
    classifier_lr = 5e-5
    classifier_epochs = 5000
    classifier_arch = "small"  # values noted: "mini", "small"
    classifier_vocab = "top-k"  # values noted: "full", "top-k"

    # Run title and model pair.
    exp_title = "fp16-mm-weight-cascade-MTC-A6000-CAPTION-240924"
    drf = "XXXX-2/lvlm68m"
    tgt = "llava-hf/llava-1.5-7b-hf"

    # Drafting cascade settings.
    drafting = ["multimodal", "text-only", "caption"]
    caption_type = "<CAPTION>"
    captioning_model = "microsoft/Florence-2-large-ft"
    cascade_rule = "mm-weight"
    mm_weight_policy = 1

@ex.named_config
def TestClassifier():
    """Named config: enable `load_classifier` and point `classifier_ckpt_path` at a saved checkpoint.

    NOTE(review): `train_classifier` is also set to True here — confirm this is
    intended for a test configuration.
    """
    # Classifier settings.
    train_classifier = True
    load_classifier = True
    classifier_top_k = 5
    classifier_batch_size = 128
    classifier_lr = 5e-5
    classifier_epochs = 5000
    classifier_arch = "small"  # values noted: "mini", "small"
    classifier_vocab = "top-k"  # values noted: "full", "top-k"

    # The pieces below concatenate to a single path string.
    classifier_ckpt_path = (
        "/XXXX-5/home-XXXX-3/data/MSD/checkpoint/classifier/"
        "sd_llava-68m_llava-llama-7b_caption-florence2-0.77b-C-multimodal-text-only"
        "-mm-weight-1-cascade-drafting_DC100_EN_mtl-128_gamma-5_t0_fp16-16_2024/"
        "5e-05-5000-5.pth"
    )

    # Model pair.
    drf = "XXXX-2/lvlm68m"
    tgt = "llava-hf/llava-1.5-7b-hf"

    # Drafting cascade settings.
    drafting = ["multimodal", "text-only", "caption"]
    caption_type = "<CAPTION>"
    captioning_model = "microsoft/Florence-2-large-ft"
    cascade_rule = "mm-weight"
    mm_weight_policy = 1


@ex.capture
def capture_config(_config):
    """Return the `_config` argument unchanged.

    The function is registered with `ex.capture`; `_config` is presumably
    filled in by Sacred with the run's configuration — confirm against the
    Sacred documentation.
    """
    return _config

@ex.named_config
def Llama68m():
    """Named config: set `drf` to "XXXX-2/lm68m" and `is_drf_text_only` to True."""
    drf = "XXXX-2/lm68m"
    is_drf_text_only = True
    # Earlier settings kept for reference:
    #   drf = "JackFram/llama-68m"
    #   is_drf_from_mllm = False
    #   exp_title = 'JackFram'


# @ex.named_config
# def Llama160m():
#     drf = "JackFram/llama-160m"
#     is_drf_from_mllm = False
#     is_drf_text_only = True
#     exp_title = 'JackFram'

@ex.named_config
def Llama290m():
    """Named config: set `drf` to "XXXX-2/lm290m" and `is_drf_text_only` to True."""
    drf = "XXXX-2/lm290m"
    is_drf_text_only = True
    # Earlier settings kept for reference:
    #   is_drf_from_mllm = False
    #   exp_title = 'JackFram'

@ex.named_config
def Vicuna68m():
    """Named config: "double7/vicuna-68m" as `drf`, flagged text-only with `is_drf_from_mllm` off."""
    drf = "double7/vicuna-68m"
    is_drf_from_mllm = False
    is_drf_text_only = True
    exp_title = "double7"

@ex.named_config
def Vicuna160m():
    """Named config: "double7/vicuna-160m" as `drf`, flagged text-only with `is_drf_from_mllm` off."""
    drf = "double7/vicuna-160m"
    is_drf_from_mllm = False
    is_drf_text_only = True
    exp_title = "double7"

@ex.named_config
def Llava68m():
    """Named config: set `drf` to "XXXX-2/lvlm68m"."""
    drf = "XXXX-2/lvlm68m"

@ex.named_config
def Llava290m():
    """Named config: set `drf` to "XXXX-2/lvlm290m"."""
    drf = "XXXX-2/lvlm290m"

@ex.named_config
def BaseLlama68m():
    """Named config: set only `drf` to "XXXX-2/lm68m" (no other entry is touched)."""
    drf = "XXXX-2/lm68m"

@ex.named_config
def BaseLlama290m():
    """Named config: set only `drf` to "XXXX-2/lm290m" (no other entry is touched)."""
    drf = "XXXX-2/lm290m"

# Decoding
@ex.named_config
def ARDecoding():
    """Named config: set `decoding` to "ard"."""
    decoding = "ard"
    # TODO (from the original): tgt = None

@ex.named_config
def SpecDecoding():
    """Named config: set `decoding` to "sd"."""
    decoding = "sd"

@ex.named_config
def CascadeMT():
    """Named config: `drafting` becomes the list multimodal + text-only."""
    drafting = [
        "multimodal",
        "text-only",
    ]

@ex.named_config
def CascadeMC():
    """Named config: `drafting` becomes the list multimodal + caption."""
    drafting = [
        "multimodal",
        "caption",
    ]

@ex.named_config
def CascadeTC():
    """Named config: `drafting` becomes the list text-only + caption."""
    drafting = [
        "text-only",
        "caption",
    ]

@ex.named_config
def EvalLvlm():
    """Named config: `eval_models` holds the single pair lvlm68m / llava-1.5-7b-hf."""
    eval_models = [("XXXX-2/lvlm68m", "llava-hf/llava-1.5-7b-hf")]

@ex.named_config
def EvalLvlm13b():
    """Named config: `eval_models` holds the single pair lvlm68m / llava-1.5-13b-hf."""
    eval_models = [("XXXX-2/lvlm68m", "llava-hf/llava-1.5-13b-hf")]

@ex.named_config
def EvalLvlm290m7b():
    """Named config: `eval_models` holds the single pair lvlm290m / llava-1.5-7b-hf."""
    eval_models = [("XXXX-2/lvlm290m", "llava-hf/llava-1.5-7b-hf")]

@ex.named_config
def EvalLvlm290m13b():
    """Named config: `eval_models` holds the single pair lvlm290m / llava-1.5-13b-hf."""
    eval_models = [("XXXX-2/lvlm290m", "llava-hf/llava-1.5-13b-hf")]

@ex.named_config
def EvalQwen():
    """Named config: `eval_models` holds the single llava-interleave-qwen 0.5b / 7b pair."""
    eval_models = [
        (
            "llava-hf/llava-interleave-qwen-0.5b-hf",
            "llava-hf/llava-interleave-qwen-7b-hf",
        ),
    ]

@ex.named_config
def EvalCascadeMT():
    """Named config: `eval_drafting` holds one cascade, multimodal + text-only."""
    eval_drafting = [
        ["multimodal", "text-only"],
    ]

@ex.named_config
def EvalCascadeMC():
    """Named config: `eval_drafting` holds one cascade, multimodal + caption."""
    eval_drafting = [
        ["multimodal", "caption"],
    ]

@ex.named_config
def EvalCascadeTC():
    """Named config: `eval_drafting` holds one cascade, text-only + caption."""
    eval_drafting = [
        ["text-only", "caption"],
    ]

@ex.named_config
def EvalCascadeMTC():
    """Named config: `eval_drafting` holds one cascade, multimodal + text-only + caption."""
    eval_drafting = [
        ["multimodal", "text-only", "caption"],
    ]

@ex.named_config
def EvalCascadeMTCP():
    """Named config: `eval_drafting` holds one cascade, multimodal + text-only + caption + image-pool."""
    eval_drafting = [
        ["multimodal", "text-only", "caption", "image-pool"],
    ]

@ex.named_config
def EvalWholeData():
    """Named config: `eval_datasets` becomes the 13-entry dataset list below."""
    eval_datasets = [
        "VibeEval",
        "DC100_EN",
        "llava-bench-in-the-wild",
        "Spot-the-Diff",
        "Birds-to-Words",
        "CLEVR-Change",
        "IEdit",
        "AESOP",
        "FlintstonesSV",
        "PororoSV",
        "VIST",
        "WebQA",
        "LiveBench",
    ]

@ex.named_config
def CascadeMTC():
    """Named config: `drafting` becomes the list multimodal + text-only + caption."""
    drafting = [
        "multimodal",
        "text-only",
        "caption",
    ]

@ex.named_config
def CascadeMTCP():
    """Named config: `drafting` becomes the list multimodal + text-only + caption + image-pool."""
    drafting = [
        "multimodal",
        "text-only",
        "caption",
        "image-pool",
    ]

# Drafting
@ex.named_config
def MultimodalDraft():
    """Named config: set `drafting` to the single mode "multimodal"."""
    drafting = "multimodal"

@ex.named_config
def TextOnlyDraft():
    """Named config: set `drafting` to the single mode "text-only"."""
    drafting = "text-only"

@ex.named_config
def CaptionDraft():
    """Named config: set `drafting` to the single mode "caption"."""
    drafting = "caption"

@ex.named_config
def TextOnlyVerify():
    """Named config: set `is_tgt_text_only` to True."""
    is_tgt_text_only = True

@ex.named_config
def HalfPrecision():
    """Named config: set both `drf_dtype` and `tgt_dtype` to "fp16"."""
    drf_dtype = "fp16"
    tgt_dtype = "fp16"

# Datasets
@ex.named_config
def LlavaData():
    """Named config: set `dataset` to "LLaVA-Instruct-150K"."""
    dataset = "LLaVA-Instruct-150K"

@ex.named_config
def CocoData():
    """Named config: set `dataset` to "COCO2014"."""
    dataset = "COCO2014"

@ex.named_config
def ScienceQAData():
    """Named config: set `dataset` to "ScienceQA" and `save_steps` to 1000."""
    dataset = "ScienceQA"
    save_steps = 1000

@ex.named_config
def VibeEvalData():
    """Named config: set `dataset` to "VibeEval"."""
    dataset = "VibeEval"

# @ex.named_config
# def LlavaBenchWilderData():
#     dataset = "LLaVA-Bench-Wilder"

@ex.named_config
def LlavaBenchInTheWildData():
    """Named config: set `dataset` to "llava-bench-in-the-wild"."""
    dataset = "llava-bench-in-the-wild"

@ex.named_config
def Dc100Data():
    """Named config: set `dataset` to "DC100_EN"."""
    dataset = "DC100_EN"

@ex.named_config
def SpotTheDiffData():
    """Named config: set `dataset` to "Spot-the-Diff"."""
    dataset = "Spot-the-Diff"

@ex.named_config
def BirdsToWordsData():
    """Named config: set `dataset` to "Birds-to-Words"."""
    dataset = "Birds-to-Words"

@ex.named_config
def ClevrChangeData():
    """Named config: set `dataset` to "CLEVR-Change"."""
    dataset = "CLEVR-Change"

@ex.named_config
def HQEditData():
    """Named config: set `dataset` to "HQ-Edit"."""
    dataset = "HQ-Edit"

@ex.named_config
def MagicBrushData():
    """Named config: set `dataset` to "MagicBrush"."""
    dataset = "MagicBrush"

@ex.named_config
def IEditData():
    """Named config: set `dataset` to "IEdit"."""
    dataset = "IEdit"

@ex.named_config
def AESOPData():
    """Named config: set `dataset` to "AESOP"."""
    dataset = "AESOP"

@ex.named_config
def FlintstonesSVData():
    """Named config: set `dataset` to "FlintstonesSV"."""
    dataset = "FlintstonesSV"

@ex.named_config
def PororoSVData():
    """Named config: set `dataset` to "PororoSV"."""
    dataset = "PororoSV"

@ex.named_config
def VISTData():
    """Named config: set `dataset` to "VIST"."""
    dataset = "VIST"

@ex.named_config
def WebQAData():
    """Named config: set `dataset` to "WebQA"."""
    dataset = "WebQA"

@ex.named_config
def LiveBenchData():
    """Named config: set `dataset` to "LiveBench"."""
    dataset = "LiveBench"

@ex.named_config
def ChartQAData():
    """Named config: set `dataset` to "chartqa"."""
    dataset = "chartqa"

@ex.named_config
def DocVQAData():
    """Named config: set `dataset` to "docvqa_val"."""
    dataset = "docvqa_val"

@ex.named_config
def InfoVQAData():
    """Named config: set `dataset` to "infovqa_val"."""
    dataset = "infovqa_val"

@ex.named_config
def OkVQAData():
    """Named config: set `dataset` to "ok_vqa_val2014"."""
    dataset = "ok_vqa_val2014"

@ex.named_config
def TextVQAData():
    """Named config: set `dataset` to "textvqa_val"."""
    dataset = "textvqa_val"

@ex.named_config
def VizWizVQAData():
    """Named config: set `dataset` to "vizwiz_vqa_val"."""
    dataset = "vizwiz_vqa_val"

@ex.named_config
def VQAV2Data():
    """Named config: set `dataset` to "vqav2_val"."""
    dataset = "vqav2_val"

@ex.named_config
def MMVetData():
    """Named config: set `dataset` to "MMVet"."""
    dataset = "MMVet"

@ex.named_config
def PopeData():
    """Named config: set `dataset` to "POPE"."""
    dataset = "POPE"

@ex.named_config
def HallusionBenchData():
    """Named config: set `dataset` to "HallusionBench"."""
    dataset = "HallusionBench"

@ex.named_config
def QBenchData():
    """Named config: set `dataset` to "QBench"."""
    dataset = "QBench"

@ex.named_config
def NLVR2MantisData():
    """Named config: set `dataset` to "NLVR2_Mantis"."""
    dataset = "NLVR2_Mantis"

@ex.named_config
def OCRVQAData():
    """Named config: set `dataset` to "OCR-VQA"."""
    dataset = "OCR-VQA"

@ex.named_config
def Evaluation():
    """Named config for the evaluation script.

    Run: python3 mllmsd/utils/evaluation.py with Evaluation

    Clears the single-run entries (`drf`, `tgt`, `decoding`) and fills the
    `eval_*` lists instead.
    """
    drf = None
    tgt = None
    decoding = None

    # NOTE(review): the first pair uses the string "None" rather than the None
    # object, and the second pair lists the 7b model first, unlike the other
    # `eval_models` pairs in this file — confirm both are intended.
    eval_models = [
        ("XXXX-2/lvlm68m", "None"),
        ("llava-hf/llava-1.5-7b-hf", "XXXX-2/lvlm68m"),
    ]
    eval_is_drf_text_only = [False]
    eval_datasets = ["VibeEval"]
    # Wider dataset list kept for reference:
    # eval_datasets = ["VibeEval", "DC100_EN", 'llava-bench-in-the-wild', 'Spot-the-Diff', 'Birds-to-Words', 'CLEVR-Change', 'IEdit', 'AESOP','FlintstonesSV', 'PororoSV', 'VIST']

    # NOTE(review): empty list here, whereas the base config default is [None] —
    # confirm the consumer handles an empty list as intended.
    eval_max_chunk_length = []
    is_time_factorized = False
    exp_title = ""

    # Plain string literal: the path has no placeholders, so no f-prefix is needed.
    npy_save_dir = "/XXXX-5/home-XXXX-3/data/MSD/npy"
    do_print = False

@ex.named_config
def EvaluationARD():
    """Named config for the evaluation script with `decoding` set to "ard".

    Run: python3 mllmsd/utils/evaluation.py with EvaluationARD
    """
    drf = None
    tgt = None
    decoding = "ard"

    # The second element of each pair is the string "None", not the None object.
    # Disabled pair kept for reference: ("XXXX-2/lvlm68m", "None")
    eval_models = [
        ("llava-hf/llava-1.5-7b-hf", "None"),
        ("llava-hf/llava-1.5-13b-hf", "None"),
    ]
    eval_is_drf_text_only = [True]
    eval_drafting = ["multimodal"]
    # eval_drafting = ['multimodal', 'text-only']
    eval_max_chunk_length = [5]
    eval_datasets = [
        "VibeEval",
        "DC100_EN",
        "llava-bench-in-the-wild",
        "Spot-the-Diff",
        "Birds-to-Words",
        "CLEVR-Change",
        "IEdit",
        "AESOP",
        "FlintstonesSV",
        "PororoSV",
        "VIST",
        "WebQA",
    ]
    exp_title = "fp16-bench-A6000"

    # Plain string literal: the path has no placeholders, so no f-prefix is needed.
    npy_save_dir = "/XXXX-5/home-XXXX-3/data/MSD/npy"
    do_print = False

@ex.named_config
def EvaluationSD():
    """Named config for the evaluation script with `decoding` set to "sd".

    Run: python3 mllmsd/utils/evaluation.py with EvaluationSD

    Only the uncommented assignments take effect; the commented lines are
    alternative settings kept from earlier experiments.
    """
    drf = None
    tgt = None
    decoding = "sd"

    # Model pairs to evaluate (presumably (drf, tgt) — confirm against the consumer).
    # Disabled pairs kept for reference:
    #   ("XXXX-2/lvlm68m-pool-0-ft", "llava-hf/llava-1.5-7b-hf"),
    #   ("XXXX-2/lvlm68m-pool-1-ft", "llava-hf/llava-1.5-7b-hf"),
    #   ("XXXX-2/lvlm68m-pool-4-ft", "llava-hf/llava-1.5-7b-hf"),
    #   ("XXXX-2/lvlm68m-pool-9-ft", "llava-hf/llava-1.5-7b-hf"),
    #   ("XXXX-2/lvlm68m-pool-36-ft", "llava-hf/llava-1.5-7b-hf"),
    #   ("XXXX-2/lvlm68m-pool-144-ft", "llava-hf/llava-1.5-7b-hf"),
    #   ("XXXX-2/lvlm68m", "llava-hf/llava-1.5-13b-hf"),
    #   ("llava-hf/llava-interleave-qwen-0.5b-hf", "llava-hf/llava-interleave-qwen-7b-hf")
    eval_models = [
        ("XXXX-2/lvlm68m", "llava-hf/llava-1.5-7b-hf"),
    ]

    output_image_attentions = False

    eval_is_drf_text_only = [False]  # deprecated
    eval_max_chunk_length = [5]

    # Disabled sweep settings kept for reference:
    # eval_drafting = ['multimodal', 'text-only', 'caption', 'tokenized-image', 'special-token', 'image-pool'] # phase 2
    # eval_drafting = [['multimodal', 'caption']]
    # eval_drafting = [[ 'multimodal', 'text-only']]
    # eval_drafting = ['multimodal', 'text-only']
    # eval_drafting = ['multimodal']
    # eval_drafting = ['caption']
    # eval_drafting = ['image-pool']
    # eval_target_dim_image_pooling = [1, 9, 36, 144]
    # eval_image_top_k_attention = [1, 16, 64, 256]
    # eval_cascade_rule = ['mm-weight']
    # eval_mm_weight_policy = ['img-nec']
    # eval_mm_weight_policy = [1, 2, 3, 4]
    # eval_captioning_model = ["microsoft/Florence-2-large-ft"]
    # eval_caption_type = ["<MORE_DETAILED_CAPTION>"]
    # eval_caption_type = ["<OCR>"]
    # eval_caption_type = ["<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>"]
    # eval_mm_weight_k = [2, 3, 4, 8]

    is_time_factorized = False
    exp_title = "fp16-mm-weight-cascade-MC-image-necessity-k-A6000"
    # Earlier titles kept for reference:
    # exp_title = "fp16-mm-weight-cascade-MT-image-necessity-k-A6000"
    # exp_title = "fp16-pool-ft-lvlm-A6000"
    # exp_title = "fp16-enhance-log-qwen-single-drf-A6000"
    # exp_title = "fp16-enhance-log-lvlm-single-drf-A6000"
    # exp_title = "fp16-mm-weight-cascade-MC-image-necessity-A6000"
    # exp_title = "fp16-mm-weight-cascade-MT-image-necessity-A6000"
    # exp_title = "fp16-caption-mdc-A6000"
    # exp_title = "fp16-mm-weight-cascade-MC-A6000"
    # exp_title = "fp16-mm-weight-cascade-MT-A6000"
    # tiny_data = True

    # Datasets. Disabled alternatives kept for reference:
    # eval_datasets = ['DC100_EN']
    # eval_datasets = ['llava-bench-in-the-wild']
    eval_datasets = [
        "VibeEval",
        "DC100_EN",
        "Spot-the-Diff",
        "Birds-to-Words",
        "CLEVR-Change",
        "IEdit",
    ]
    # Except livebench
    # eval_datasets = ["VibeEval", "DC100_EN", 'llava-bench-in-the-wild', 'Spot-the-Diff', 'Birds-to-Words', 'CLEVR-Change', 'IEdit', 'AESOP','FlintstonesSV', 'PororoSV', 'VIST', 'WebQA']
    # Whole dataset
    # eval_datasets = ["VibeEval", "DC100_EN", 'llava-bench-in-the-wild', 'Spot-the-Diff', 'Birds-to-Words', 'CLEVR-Change', 'IEdit', 'AESOP','FlintstonesSV', 'PororoSV', 'VIST', 'WebQA', 'LiveBench']

    # Plain string literal: the path has no placeholders, so no f-prefix is needed.
    npy_save_dir = "/XXXX-5/home-XXXX-3/data/MSD/npy"
    do_print = False

@ex.named_config
def T5Measurement():
    """Named config: point `drf` and `tgt` at the same T5 checkpoint on the "xsum" dataset.

    Also sets `max_prompt_length` to 1024 and `max_target_length` to 64.
    """
    # Set `drf` to the model whose time is to be measured.
    drf = "google/t5-small-lm-adapt"
    tgt = "google/t5-small-lm-adapt"

    dataset = "xsum"
    max_prompt_length = 1024
    max_target_length = 64

    # Checkpoint id -> short name, from the original notes:
    #   "google/t5-small-lm-adapt": "T5lm-small",
    #   "google/t5-base-lm-adapt": "T5lm-base",
    #   "google/t5-large-lm-adapt": "T5lm-large",
    #   "google/t5-xl-lm-adapt": "T5lm-xl",
    #   "google/t5-xxl-lm-adapt": "T5lm-xxl",

@ex.named_config
def Debug():
    """Named config for debugging runs: `debug`, `do_print` and `tiny_data` are all True.

    Only the uncommented assignments take effect; the commented lines are
    alternative settings kept from earlier debugging sessions.
    """
    # Models: `drf` and `tgt` point at the same checkpoint.
    # Alternatives kept for reference:
    #   drf = "llava-hf/llava-interleave-qwen-0.5b-hf"
    #   tgt = "llava-hf/llava-interleave-qwen-0.5b-hf"
    #   drf = "XXXX-2/lvlm68m-pool-1-ft"
    #   tgt = "llava-hf/llava-1.5-7b-hf"
    #   drf = "XXXX-2/lvlm290m"
    #   tgt = "XXXX-2/lvlm290m"
    #   tgt = "XXXX-2/lm68m"
    #   tgt = None
    #   is_drf_text_only = False
    drf = "XXXX-2/lvlm68m"
    tgt = "XXXX-2/lvlm68m"

    # Drafting. Alternatives kept for reference:
    #   logging_top_k = 3
    #   drafting = 'caption'
    #   drafting = ['multimodal', 'text-only']
    #   drafting = ['multimodal', 'text-only', 'caption']
    #   drafting = ['multimodal', 'multimodal-debug']
    drafting = ["multimodal", "multimodal-debug", "multimodal-debug2"]
    captioning_model = "microsoft/Florence-2-large-ft"
    # cascade_rule = 'confidence'
    # cascade_rule = 'dist-sum'
    cascade_rule = "mm-weight"
    mm_weight_policy = 1
    # mm_weight_k = 2
    max_target_length = 128
    # image_top_k_attention = 8
    # is_drf_text_only = True
    # target_dim_image_pooling = 1
    # image_pool_type = 'avg2d'
    caption_type = "<OCR>"
    # caption_type = "<MORE_DETAILED_CAPTION>"

    # Decoding and precision.
    decoding = "sd"
    drf_dtype = "fp16"
    tgt_dtype = "fp16"
    # drf_dtype = "fp32"
    # tgt_dtype = "fp32"

    # Run behaviour.
    do_print = True
    is_time_factorized = False
    debug = True
    save_steps = 5
    max_chunk_length = 5
    output_image_attentions = False

    # Dataset. Alternatives kept for reference:
    #   "HallusionBench", "POPE", "QBench", "NLVR2_Mantis", "OCR-VQA",
    #   'chartqa', 'docvqa_val', 'infovqa_val', 'ok_vqa_val2014',
    #   'textvqa_val', 'vizwiz_vqa_val', 'vqav2_val', 'llava-bench-in-the-wild'
    dataset = "DC100_EN"
    tiny_data = True
    classifier_top_k = 5

"""
mm weight
['1.0820', '0.7559', '0.6094', '1.0004', '0.9820']
['1.0007', '0.9293', '1.0030', '0.9879', '1.2031']
['1.6235', '1.4316', '1.3548', '1.2343', '1.0068']
['0.6992', '0.9940', '0.0000', '0.5626', '0.0452']
['1.2533', '1.1588', '1.5632', '0.8723', '0.9987']

dist
['1.0820', '0.7559', '0.6094', '1.0004', '0.9820']
['1.0007', '0.9293', '1.0030', '0.9879', '1.2031']
['1.6235', '1.4316', '1.3548', '1.2343', '1.0068']
['0.6992', '0.9940', '0.0000', '0.5626', '0.0452']
['1.2533', '1.1588', '1.5632', '0.8723', '0.9987']

['1.0820', '0.7559', '0.6094', '1.0004', '0.9820']
['1.0007', '0.9293', '1.0030', '0.9879', '1.2031']
['1.6235', '1.4316', '1.3548', '1.2343', '1.0068']
['0.6992', '0.9940', '0.0000', '0.5626', '0.0452']
['1.2533', '1.1588', '1.5632', '0.8723', '0.9987']

['1.0820', '0.7559', '0.6094', '1.0004', '0.9820']
['1.0007', '0.9293', '1.0030', '0.9879', '1.2031']
['1.6235', '1.4316', '1.3548', '1.2343', '1.0068']
['0.6992', '0.9940', '0.0000', '0.5626', '0.0452']
['1.2533', '1.1588', '1.5632', '0.8723', '0.9987']
"""

