import os
import torch
import random
import numpy as np
from PIL import Image
from io import BytesIO
import datasets
import logging
from datasets import load_dataset, load_from_disk, DatasetDict
from collections import Counter
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import pandas as pd
import multiprocessing
from itertools import chain
import re
import json
import bisect
import string
from utils import get_dataset_type, is_none, dataset_name_split_mapping
import argparse
import transformers
from transformers import (
    AutoTokenizer,
    LlamaTokenizer,
    default_data_collator,
)
# from datasets import disable_caching
# disable_caching()

# Zero-shot prompt templates, keyed by "<dataset type or dataset>-<purpose>".
# A value is either a single str.format template or a list of alternative
# templates from which callers pick one entry with an integer template_index.
# Placeholders used: {question}, {unit}, {hint}, {key}, {item}, {caption}.
# Templates ending in an opening quote (" or ') are meant to be continued
# directly by the answer; when such a template is used for a few-shot
# demonstration, the answer is followed by the matching closing quote.
# NOTE: the "Winorground" keys are misspelled, but callers look them up by that
# exact spelling, so they cannot be renamed here alone.
zero_shot_template = {
    # Building blocks applied step by step to assemble a multi-choice question:
    # question text, optional unit, optional hint, then an "Options:" header
    # followed by one "<key>. <item>" line per option.
    "multi-choice-question": "Question: {question}",
    "multi-choice-question-unit": "{question} (Unit: {unit})",
    "multi-choice-hint-question": "Hint: {hint}\n{question}",
    "multi-choice-question-option": "{question}\nOptions:",
    "multi-choice-question-key-item": "{question}\n{key}. {item}",
    # Multi-choice wrappers for prompt settings ending in '_choices_ppl' on
    # datasets that include an image.
    "multi-choice-question-choices-ppl-list": [
        "{question}\nAnswer:",
        "{question}\nThe answer is: \"",
        "{question}\nThe answer is \"",
        "{question}\nThe answer is '",
        "{question}\nPlease select the correct answer from the options above based on the image. The answer is: \"",
        "{question}\nPlease select the correct answer from the options above based on the image. The answer is \"",
        "Please select the correct answer from the options above based on the image.\n{question}\nAnswer:",
        "Please select the correct answer from the options above based on the image.\n{question}\nAnswer: \"",
        # "Please select the correct answer from the options above based on the image.\n{question}\nThe answer is: \"",
        # "Please select the correct answer from the options above based on the image.\n{question}\nThe answer is \"",
    ],
    # Same as above, but with no reference to an image; used for '_choices_ppl'
    # settings on the text-only datasets (MMLU, PIQA_VAL).
    "multi-choice-question-choices-ppl-list-no-image": [
        "{question}\nPlease select the correct answer from the options above. The answer is \"",
        "{question}\nPlease select the correct answer from the options above.\nThe answer is:",
        "{question}\nPlease select the correct answer from the options above. The answer is: \"",
        "{question}\nThe answer is \"",
        "Please select the correct answer from the options above.\n{question}\nAnswer: \"",
        "{question}\nThe answer is: \"",
        "Please select the correct answer from the options above.\n{question}\nThe answer is:",
        "{question}\nThe answer is '",
    ],
    # Multi-choice wrapper for prompt settings that do not end in '_choices_ppl'.
    "multi-choice-question-list": [
        "{question}\nAnswer with the option's letter from the given options directly. The answer is"
    ],
    # Yes/no templates for '_choices_ppl' settings, one list per dataset
    # (Winoground-YN is formatted with {caption}; the others with {question}).
    "Y/N-Winorground-choices-ppl-caption-list": [
        "Please pay close attention to the word order and directly answer 'yes' or 'no'.\nQuestion: Does the image match the caption \"{caption}\"?",
        "Please pay close attention to the word order and directly answer 'yes' or 'no'.\nQuestion: Does the image match the caption \"{caption}\"? The answer is: '",
        "Please pay close attention to the word order and directly answer 'yes' or 'no'.\nQuestion: Does the image match the caption \"{caption}\"?\nThe answer is: '",
        "Please directly answer 'yes' or 'no'.\nQuestion: Pay close attention to the word order, does the image match the caption \"{caption}\"? The answer is: '",
        "Question: Does the image match the caption \"{caption}\"? '",
        "Does the image match the caption \"{caption}\"? '",
        "Question: Does the image match the caption \"{caption}\"?",
        "Does the image match the caption \"{caption}\"?",
    ],
    "Y/N-MME-choices-ppl-question-list": [
        "{question}",
        "{question}\nThe answer is: '", #
        "Question: {question} The answer is: '",
        "{question} Please answer 'yes' or 'no' based on the image. The answer is: '",
        "Question: {question} Please answer 'yes' or 'no' based on the image. The answer is: '",
        "Question: {question}\nPlease answer 'yes' or 'no' based on the image. The answer is: '", #
        "Please answer 'yes' or 'no' based on the image.\n{question}",
        "Please answer 'yes' or 'no' based on the image. {question}\nThe answer is: '",
        # "Please answer 'yes' or 'no' based on the image.\n{question}\nThe answer is: '",
        # "Please answer 'yes' or 'no' based on the image.\nQuestion: {question} The answer is: '",
    ],
    "Y/N-HallusionBench-choices-ppl-question-list": [
        "{question}\nPlease answer 'yes' or 'no' based on the image. The answer is: ", #
        "{question}\nPlease answer 'yes' or 'no' based on the image. The answer is: '",
        "Question: {question} Please answer 'yes' or 'no' based on the image. The answer is: ", #
        "Question: {question}\nPlease answer 'yes' or 'no' based on the image. The answer is: '",
        "Question: {question}",
        "{question} Please answer 'yes' or 'no' based on the image. The answer is: '", #
        "Please answer 'yes' or 'no' based on the image.\nQuestion: {question}",
        "Please answer 'yes' or 'no' based on the image.\n{question}",

        # "{question}",
        # "Please answer 'yes' or 'no' based on the image.\nQuestion: {question} The answer is: '",
        # "Please answer 'yes' or 'no' based on the image.\n{question} The answer is: '",
        # "Please answer 'yes' or 'no' based on the image.\nQuestion: {question}\nThe answer is: '",
        # "Please answer 'yes' or 'no' based on the image. {question} The answer is: '",
        # "{question} Please answer 'yes' or 'no' based on the image.\nThe answer is: '",
        # "{question}\nPlease answer 'yes' or 'no' based on the image.\nThe answer is: '",
        # "{question}\nPlease answer 'yes' or 'no' based on the image.\nThe answer is:",
    ],
    "Y/N-POPE-choices-ppl-question-list": [
        "Question: {question} The answer is: '",
        "{question} Please answer yes or no based on the image. The answer is: '",
        "{question}\nPlease answer yes or no based on the image. The answer is: '",
        "Question: {question} Please answer yes or no based on the image. The answer is: '",
        "Question: {question}\nPlease answer yes or no based on the image. The answer is: '", #
        "Question: {question}\nPlease answer 'yes' or 'no' based on the image. The answer is: '", #
        "Please answer 'yes' or 'no' based on the image.\nQuestion: {question}\nThe answer is: '",
        "Please answer 'yes' or 'no' based on the image.\n{question} The answer is: '",

        # "{question}",
        # "{question}\nThe answer is: '",
        # "Question: {question} The answer is:",
        # "{question} Please answer 'yes' or 'no' based on the image. The answer is: '",
        # "Question: {question} Please answer 'yes' or 'no' based on the image. The answer is: '",
        # "Please answer 'yes' or 'no' based on the image.\n{question}",
        # "Please answer 'yes' or 'no' based on the image. {question} The answer is: '",
        # "Please answer 'yes' or 'no' based on the image. {question}\nThe answer is: '",
    ],
    # Yes/no templates for settings that do not end in '_choices_ppl':
    # Winoground-YN uses the caption list, every other Y/N dataset the question list.
    "Y/N-Winorground-caption-list": [
        # "Caption: {caption}\nQuestion: Does the image match the caption? The answer is"
        "Question: Does the image match the caption \"{caption}\"? The answer is"
    ],
    "Y/N-question-list": [
        "Question: {question}\nThe answer is"
    ],
    # Open-ended VQA: unit/hint decorations (applied for MathVista_OpenEnded),
    # a VizWiz-specific list (VizWiz_VAL / VizWiz_TEST) and a generic list.
    "VQA-question-unit": "{question} (Unit: {unit})",
    "VQA-hint-question": "Hint: {hint}\n{question}",
    "VQA-VizWiz-question-list": [
        "When the provided information is insufficient, respond with 'unanswerable'. Answer the question using a single word or phrase. Question: {question} The answer is \"",
        "When the provided information is insufficient, respond with 'unanswerable'. Answer the question using a single word or phrase.\nQuestion: {question} The answer is \"",
        "When the provided information is insufficient, respond with 'unanswerable'. Answer the question using a single word or phrase.\nQuestion: {question}\nThe answer is \"",
        "When the provided information is insufficient, respond with 'unanswerable'. Answer the question using a single word or phrase. Question: {question} Answer: \"",
        "When the provided information is insufficient, respond with 'unanswerable'. Answer the question using a single word or phrase.\nQuestion: {question} Answer: \"",
        "When the provided information is insufficient, respond with 'unanswerable'. Answer the question using a single word or phrase.\nQuestion: {question} Answer:",
        "Question: {question} When the provided information is insufficient, respond with 'unanswerable'. Answer the question using a single word or phrase. The answer is \"",
        "Question: {question}\nWhen the provided information is insufficient, respond with 'unanswerable'. Answer the question using a single word or phrase.\nThe answer is \"",
    ],
    "VQA-question-list": [
        "Answer the question using a single word or phrase. Question: {question} The answer is \"",
        "Answer the question using a single word or phrase.\nQuestion: {question} The answer is \"",
        "Answer the question using a single word or phrase.\nQuestion: {question}\nThe answer is \"",
        "Answer the question using a single word or phrase.\nQuestion: {question} The answer is",
        "Question: {question} Answer the question using a single word or phrase. The answer is \"",
        "Question: {question}\nAnswer the question using a single word or phrase. The answer is \"",
        "Question: {question}\nAnswer the question using a single word or phrase.\nThe answer is \"",
        "Question: {question} Answer the question using a single word or phrase. The answer is:",
    ],
    # Captioning prompts; these contain no placeholders and are used as-is.
    "Caption-list": [
        "Caption:"
    ]
}

# Text used to frame few-shot demonstrations around the actual input.
# Presumably shot_start precedes the demonstrations and shot_end precedes the
# new input; confirm against the prompt-assembly code.
shot_start = "# Few-shot Examples\n\n"
# Appended after every individual demonstration.
shot_delimiter = "\n\n"
# The leading "<s>" is a placeholder that is meant to be removed later. It exists so
# that tokenizing the pieces separately and concatenating the token ids gives the
# same ids as concatenating the texts first and tokenizing once.
shot_end = "<s># New Input for Processing\n\n"


def parse_args(argv=None):
    """Parse the command-line options for dataset preprocessing.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv`` exactly as before. Passing a list
            lets tests and other scripts drive the parser directly.

    Returns:
        argparse.Namespace with the parsed options. ``prompt_settings`` and
        ``num_templates`` are still raw comma-separated strings; callers split
        and convert them.
    """
    # Query the core count once; it is both the worker default and the printed value.
    num_cores = multiprocessing.cpu_count()

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default='YOUR_ROOT_PATH/model/llama2-1229/Llama-2-7b-hf', help='path to LaVIT checkpoint')
    parser.add_argument('--output_dir', type=str, default='YOUR_ROOT_PATH/data/MLLM/Evaluation', help='path to save the output')

    parser.add_argument("--from_hf", action="store_true", help="Whether to use huggingface datasets.")
    parser.add_argument('--dataset_name', type=str, default='MMBench_DEV_EN', help='dataset name')
    parser.add_argument('--prompt_settings', type=str, default='zero_shot,zero_shot_choices_ppl', help='prompt settings')
    parser.add_argument('--num_templates', type=str, default="1,1", help='list of number of supported templates for different prompt settings')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--image_pad_token_id', type=int, default=-1, help='pad token id')
    parser.add_argument('--image_start_token_id', type=int, default=32000, help='image start token id')
    parser.add_argument('--image_end_token_id', type=int, default=32001, help='image end token id')
    parser.add_argument('--max_length', type=int, default=2048, help='max length')
    parser.add_argument('--process_batch_size', type=int, default=200, help='process batch size')
    parser.add_argument('--process_num_workers', type=int, default=num_cores, help='preprocessing num workers')
    print('Number of available cores:', num_cores)

    return parser.parse_args(argv)

def statistics(dataset_dict, dataset_name, prompt_settings, num_templates):
    """Print the row count and ``input_tokens`` length stats of every split.

    Args:
        dataset_dict: Mapping (e.g. a ``DatasetDict``) from split name
            ``f"{prompt_setting}_{template_index}"`` to a dataset that exposes
            ``num_rows`` and an ``'input_tokens'`` column of sequences.
        dataset_name: Name shown in each printed header.
        prompt_settings: Prompt-setting names, paired positionally with
            ``num_templates``.
        num_templates: Number of template variants for each prompt setting.
    """
    for prompt_setting, num_template in zip(prompt_settings, num_templates):
        for template_index in range(num_template):
            split_name = f"{prompt_setting}_{template_index}"
            split = dataset_dict[split_name]
            length_list = [len(x) for x in split['input_tokens']]
            print(f"###############{dataset_name}_{split_name}: {split.num_rows}###############")
            if not length_list:
                # np.max / np.min raise ValueError on empty input (np.mean returns nan),
                # so report the empty split instead of crashing the whole summary.
                print("input_tokens_length: no examples")
                continue
            input_tokens_length_mean = np.mean(length_list)
            input_tokens_length_max = np.max(length_list)
            input_tokens_length_min = np.min(length_list)
            print(f"input_tokens_length mean: {input_tokens_length_mean:.2f}, input_tokens_length_max: {input_tokens_length_max}, input_tokens_length_min: {input_tokens_length_min}")

def test():
    """Print one randomly chosen, decoded example from every processed split.

    Debugging helper. It reloads the ``eval`` dataset saved under
    ``<output_dir>/<dataset_name>/eval``. For every
    ``{prompt_setting}_{template_index}`` split it then:

    * picks a random row;
    * splits that row's ``input_tokens`` into maximal runs of text ids and
      image ids, where ids ``>= args.image_start_token_id`` count as image ids;
    * prints the decoded text runs joined by ``"[image]"``;
    * prints the first image run, except for the text-only datasets MMLU
      and PIQA_VAL.
    """
    from itertools import groupby

    args = parse_args()

    eval_path = os.path.join(args.output_dir, args.dataset_name, 'eval')
    final_dataset = load_from_disk(eval_path)

    dataset_type = get_dataset_type(args.dataset_name)
    prompt_settings = args.prompt_settings.split(',')
    if dataset_type not in ['multi-choice', 'Y/N']:
        # '*_choices_ppl' settings are only kept for multi-choice and Y/N datasets.
        prompt_settings = [prompt_setting for prompt_setting in prompt_settings if not prompt_setting.endswith('_choices_ppl')]
    args.prompt_settings = prompt_settings
    # One template count per remaining prompt setting (extra counts are dropped).
    args.num_templates = [int(num_template) for num_template in args.num_templates.split(',')[:len(args.prompt_settings)]]

    tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=False, legacy=False)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.add_special_tokens({'additional_special_tokens': ['<image>']}, replace_additional_special_tokens=False)
    tokenizer.add_special_tokens({'additional_special_tokens': ['</image>']}, replace_additional_special_tokens=False)
    # The stored token ids assume args.image_start_token_id; make sure the
    # tokenizer assigned that id to the newly added "<image>" token.
    image_start_token_id = tokenizer.additional_special_tokens_ids[0]
    assert image_start_token_id == args.image_start_token_id

    image_id_threshold = args.image_start_token_id
    for prompt_setting, num_template in zip(args.prompt_settings, args.num_templates):
        for template_index in range(num_template):
            split_name = f"{prompt_setting}_{template_index}"
            split = final_dataset[split_name]
            if split.num_rows == 0:
                print(f"{split_name}: no rows, skipped")
                continue
            # randrange excludes its upper bound; randint(0, num_rows) could
            # return num_rows and index one past the last row.
            input_index = random.randrange(split.num_rows)
            input_tokens = split[input_index]['input_tokens']

            # Group consecutive ids into alternating runs of text ids and image ids.
            image_token_lists = []
            text_token_lists = []
            for is_image, run in groupby(input_tokens, key=lambda token_id: token_id >= image_id_threshold):
                if is_image:
                    image_token_lists.append(list(run))
                else:
                    text_token_lists.append(list(run))

            text_tokens = "[image]".join(tokenizer.decode(text_token_list) for text_token_list in text_token_lists)
            print(f"input_index: {input_index}, text_token_lists: {len(text_token_lists)}, image_token_lists: {len(image_token_lists)}")
            print(text_tokens)
            if args.dataset_name not in ["MMLU", "PIQA_VAL"]:
                print(image_token_lists[0])

def main():
    args = parse_args()
    
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    
    datasets.utils.logging.set_verbosity_warning()
    transformers.utils.logging.set_verbosity_info()
    
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # llama_tokenizer = LlamaTokenizer.from_pretrained(args.model_path, subfolder='language_model', use_fast=False)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=False, legacy=False)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.add_special_tokens({'additional_special_tokens': ['<image>']}, replace_additional_special_tokens=False)
    tokenizer.add_special_tokens({'additional_special_tokens': ['</image>']}, replace_additional_special_tokens=False)
    # tokenizer.add_tokens([f"<image_{str(i)}>" for i in range(16384)]) # 48386-32000-2
    image_start_token = tokenizer.additional_special_tokens[0]
    image_end_token = tokenizer.additional_special_tokens[1]
    image_start_token_id = tokenizer.additional_special_tokens_ids[0]
    image_end_token_id = tokenizer.additional_special_tokens_ids[1]
    assert image_start_token_id == args.image_start_token_id

    image_token_path = os.path.join(args.output_dir, args.dataset_name, 'image_token')
    if args.from_hf:
        text_path = os.path.join(args.output_dir, args.dataset_name)
    else:
        text_path = os.path.join(args.output_dir, args.dataset_name, 'datasets')
    if args.dataset_name not in ["MMLU", "PIQA_VAL"]:
        image_token_dataset = load_from_disk(image_token_path)
        print(image_token_dataset)
    text_dataset = load_from_disk(text_path)
    dataset_type = get_dataset_type(args.dataset_name)
    print(text_dataset)
    print(args.dataset_name, dataset_type)
    text_dataset.cleanup_cache_files()
    # prompt_settings = ['zero_shot', 'zero_shot_cot', 'zero_shot_cod']
    prompt_settings = args.prompt_settings.split(',')
    if dataset_type not in ['multi-choice', 'Y/N']:
        prompt_settings = [prompt_setting for prompt_setting in prompt_settings if not prompt_setting.endswith('_choices_ppl')]
    args.prompt_settings = prompt_settings
    args.num_templates = args.num_templates.split(',')[:len(args.prompt_settings)]
    args.num_templates = [int(num_template) for num_template in args.num_templates]

    def tokenize_by_llama(inputs):
        return tokenizer(inputs, add_special_tokens=False)["input_ids"]

    def replace_image_tokens(input_token_ids, image_token_ids):
        # when input only contain one image
        image_position_index = input_token_ids.index(image_start_token_id)
        new_input_token_ids =  input_token_ids[:image_position_index + 1] + image_token_ids + [image_end_token_id] + input_token_ids[image_position_index + 1:]
        # remove the additional blank token
        if new_input_token_ids[0] == 7084:
            new_input_token_ids[0] = 2940
        return new_input_token_ids

    def prepare_few_shot_examples(prompt_setting, template_index):
        if args.dataset_name in ["MMLU"]:
            output_examples = {}
        
        if dataset_type == 'multi-choice' and args.dataset_name in ["MMLU"]:
            # MMLU use dev split as few-shot examples
            subjects = ['abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge', 'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_medicine', 'college_physics', 'computer_security', 'conceptual_physics', 'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic', 'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science', 'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics', 'high_school_mathematics', 'high_school_microeconomics', 'high_school_physics', 'high_school_psychology', 'high_school_statistics', 'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality', 'international_law', 'jurisprudence', 'logical_fallacies', 'machine_learning', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios', 'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law', 'professional_medicine', 'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy', 'virology', 'world_religions']
            
            for subject in subjects:
                task_prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format(subject.replace("_", " "))
                output_examples[subject] = [task_prompt]
                examples = text_dataset['dev'].filter(lambda x: x['subject'] == subject)
                for example_index in range(examples.num_rows):
                    question = zero_shot_template["multi-choice-question"].format(question=examples['question'][example_index].strip())
                    options = {
                        cand: examples["choices"][example_index][cand_index]
                        for cand_index, cand in enumerate(string.ascii_uppercase[:len(examples["choices"][example_index])])
                    }

                    if len(options):
                        question = zero_shot_template["multi-choice-question-option"].format(question=question.strip())
                        for key, item in options.items():
                            question = zero_shot_template["multi-choice-question-key-item"].format(question=question.strip(), key=key.strip(), item=item.strip())
                        if prompt_setting.endswith('_choices_ppl'):
                            question = zero_shot_template["multi-choice-question-choices-ppl-list-no-image"][template_index].format(question=question.strip())
                        else:
                            question = zero_shot_template["multi-choice-question-list"][template_index].format(question=question.strip())
                    else:
                        raise ValueError("options is empty")
                    
                    if question.endswith('"'):
                        question = f"{question}{examples['answer_transformed'][example_index]}\"{shot_delimiter}"
                    elif question.endswith("'"):
                        question = f"{question}{examples['answer_transformed'][example_index]}'{shot_delimiter}"
                    else:
                        question = f"{question} {examples['answer_transformed'][example_index]}{shot_delimiter}"

                    output_examples[subject].append(question)
                
                output_examples[subject] = tokenize_by_llama("".join(output_examples[subject]))

        return output_examples
    
    def prompt_and_tokenize_function(examples, idxes, prompt_setting, template_index, current_few_shot_examples):
        examples['input_tokens'] = []
        examples['input_index'] = idxes                
        
        if dataset_type in ['multi-choice']:
            examples['choice_count'] = []

        if prompt_setting in ['zero_shot_cot', 'zero_shot_cot_choices_ppl']:
            examples['image_tokens'] = []
            examples['first_round_text'] = []
            examples['second_round_text'] = []
        
        if prompt_setting in ['zero_shot_cod', 'zero_shot_cod_choices_ppl']:
            examples['image_tokens'] = []
            examples['first_round_text'] = []
            examples['second_round_text'] = []
            examples['third_round_text'] = []
            examples['fourth_round_text'] = []
            examples['fifth_round_text'] = []
        
        for i in range(len(examples['input_index'])):
            if dataset_type == 'multi-choice':
                # question = f"Question: {examples['question'][i]}"
                question = zero_shot_template["multi-choice-question"].format(question=examples['question'][i].strip())
                hint = examples['hint'][i] if 'hint' in examples and not is_none(examples['hint'][i]) else None
                
                if args.dataset_name in ['MathVista_MultiChoice']:
                    hint = examples['query'][i].split('Question: ')[0].split('Hint: ')[1]
                    assert hint != ""
                    if examples['unit'][i] is not None:
                        # print(examples['unit'][i])
                        # question = f"{question} (Unit: {examples['unit'][i]})"
                        question = zero_shot_template["multi-choice-question-unit"].format(question=question.strip(), unit=examples['unit'][i].strip())

                if not is_none(hint):
                    # question = f"Hint: {hint.strip()}\n{question}"
                    question = zero_shot_template["multi-choice-hint-question"].format(hint=hint.strip(), question=question.strip())
                    # question = f"{hint}\n{question}"
                
                if args.dataset_name in ['MMMU_VAL_MultiChoice', 'MathVista_MultiChoice']:
                    if args.dataset_name in ['MMMU_VAL_MultiChoice']:
                        options = eval(examples['options'][i])
                    elif args.dataset_name in ['MathVista_MultiChoice']:
                        options = examples['choices'][i]
                    num_option = len(options)
                    options_dict = {
                        cand: options[cand_idx]
                        for cand_idx, cand in enumerate(string.ascii_uppercase[:num_option])
                    }
                    options = options_dict
                
                elif args.dataset_name in ["MMLU"]:
                    options = {
                        cand: examples["choices"][i][cand_index]
                        for cand_index, cand in enumerate(string.ascii_uppercase[:len(examples["choices"][i])])
                    }
                else:
                    options = {
                        cand: examples[cand][i]
                        for cand in string.ascii_uppercase
                        if cand in examples and not is_none(examples[cand][i])
                    }
                    
                examples['choice_count'].append(len(options))

                if len(options):
                    # question = f"{question}\nOptions:"
                    question = zero_shot_template["multi-choice-question-option"].format(question=question.strip())
                    for key, item in options.items():
                        # question = f"{question}\n{key}. {item}"
                        question = zero_shot_template["multi-choice-question-key-item"].format(question=question.strip(), key=key.strip(), item=item.strip())
                    if prompt_setting.endswith('_choices_ppl'):
                        if args.dataset_name in ["MMLU", "PIQA_VAL"]:
                            question = zero_shot_template["multi-choice-question-choices-ppl-list-no-image"][template_index].format(question=question.strip())
                            if args.dataset_name in ["PIQA_VAL"]:
                                question = "The following are multiple choice questions which requires physical commonsense to be answered correctly.\n" + question
                        else:
                            question = zero_shot_template["multi-choice-question-choices-ppl-list"][template_index].format(question=question.strip())
                    else:
                        question = zero_shot_template["multi-choice-question-list"][template_index].format(question=question.strip())
                else:
                    raise ValueError("options is empty")
                    # if prompt_setting.endswith('_choices_ppl'):
                    #     raise ValueError(f"prompt_setting: {prompt_setting} is not supported for zero option setting")
                    #     # question = f"{question}\nAnswer the question directly.\nAnswer:"
                    # else:
                    #     question = [
                    #         f"{question}\nAnswer the question directly. The answer is"
                    #     ][template_index]
            
            elif dataset_type == 'Y/N':
                if prompt_setting.endswith('_choices_ppl'):
                    if args.dataset_name in ['Winoground-YN']:
                        # # old template, put caption in the front
                        # caption = f"Caption: {examples['caption'][i]}\n"
                        # question = [
                        #     f"{caption}Please pay close attention to the word order and directly answer 'yes' or 'no'.\nQuestion: Does the image match the caption?",
                        #     f"{caption}Please pay close attention to the word order and directly answer 'yes' or 'no'.\nQuestion: Does the image match the caption? The answer is: '",
                        #     f"{caption}Please pay close attention to the word order and directly answer 'yes' or 'no'.\nQuestion: Does the image match the caption?\nThe answer is: '",
                        #     f"{caption}Please directly answer 'yes' or 'no'.\nQuestion: Pay close attention to the word order, does the image match the caption? The answer is: '",
                        #     f"{caption}Question: Does the image match the caption? '",
                        #     f"{caption}Does the image match the caption? '",
                        #     f"{caption}Question: Does the image match the caption?",
                        #     f"{caption}Does the image match the caption?",
                        # ][template_index]
                        question = zero_shot_template["Y/N-Winorground-choices-ppl-caption-list"][template_index].format(caption=examples['caption'][i].strip())
                    
                    if args.dataset_name in ['MME']:
                        examples['question'][i] = examples['question'][i].replace(" Please answer yes or no.", "")
                        question = zero_shot_template["Y/N-MME-choices-ppl-question-list"][template_index].format(question=examples['question'][i].strip())
                    
                    if args.dataset_name in ['HallusionBench']:
                        examples['question'][i] = examples['question'][i].replace(" Answer in one sentence.", "")
                        examples['question'][i] = examples['question'][i].replace(" Please answer yes or no.", "")
                        examples['question'][i] = examples['question'][i].replace(" Answer in one word.", "")
                        examples['question'][i] = examples['question'][i].replace(" Yes or No", "")
                        question = zero_shot_template["Y/N-HallusionBench-choices-ppl-question-list"][template_index].format(question=examples['question'][i].strip())
                    
                    if args.dataset_name in ['POPE']:
                        question = zero_shot_template["Y/N-POPE-choices-ppl-question-list"][template_index].format(question=examples['question'][i].strip())
                else:
                    if args.dataset_name in ['Winoground-YN']:
                        question = zero_shot_template["Y/N-Winorground-caption-list"][template_index].format(caption=examples['caption'][i].strip())
                    else:
                        question = zero_shot_template["Y/N-question-list"][template_index].format(question=examples['question'][i].strip())

            elif dataset_type == 'VQA':
                if args.dataset_name in ['MathVista_OpenEnded']:
                    hint = examples['query'][i].split('Question: ')[0].split('Hint: ')[1]
                    assert hint != ""
                    if examples['unit'][i] is not None:
                        # print(examples['unit'][i])
                        # examples['question'][i] = f"{examples['question'][i]} (Unit: {examples['unit'][i]})"
                        examples['question'][i] = zero_shot_template["VQA-question-unit"].format(question=examples['question'][i].strip(), unit=examples['unit'][i].strip())

                if args.dataset_name in ["VizWiz_VAL", "VizWiz_TEST"]:
                    question = zero_shot_template["VQA-VizWiz-question-list"][template_index].format(question=examples['question'][i].strip())
                else:
                    question = zero_shot_template["VQA-question-list"][template_index].format(question=examples['question'][i].strip())

                if args.dataset_name in ['MathVista_OpenEnded']:
                    # question = f"Hint: {hint.strip()}\n{question}"
                    question = zero_shot_template["VQA-hint-question"].format(hint=hint.strip(), question=question.strip())

            elif dataset_type == 'Caption':
                # question = "Based on the image, give the image caption briefly.\nCaption:"
                # question = [
                #     "Caption:",
                # ][template_index]
                question = zero_shot_template["Caption-list"][template_index]
            
            else:
                raise ValueError(f"dataset_type: {dataset_type} is not supported")

            if not args.from_hf: # or args.dataset_name in ["WHOOPS-VQA"]
                image_index = examples['image_index'][i]
            else:
                image_index = idxes[i]
            
            if args.dataset_name in ["MMLU", "PIQA_VAL"]:
                image_index = None

            if image_index is not None:
                image = f"Image: {image_start_token}\n"
            else:
                image = ""
            
            if args.dataset_name in ["MMMU_VAL_MultiChoice", "MMMU_VAL_OpenEnded"]:
                image_index_sequence = re.findall(r'<image (\d+)>', question)
                image_index_sequence = [int(x) - 1 for x in image_index_sequence]
                question = re.sub(r'<image (\d+)>', image_start_token, question)
                input_tokens = tokenize_by_llama(f"{question}")
            else:
                input_tokens = tokenize_by_llama(f"{image}{question}")
            
            if image_index is not None:
                if args.dataset_name in ["MMMU_VAL_MultiChoice", "MMMU_VAL_OpenEnded"]:
                    image_index_list = sample_index2image_index_list[image_index]
                    image_token_ids = []
                    for image_index in image_index_list:
                        image_token_ids.append(image_token_dataset['train'][image_index]['image_tokens'])
                        image_token_ids[-1] = image_token_ids[-1][:image_token_ids[-1].index(args.image_pad_token_id)]
                else:
                    if args.dataset_name in ["HallusionBench"]:
                        image_index = sample_index2image_index[image_index]
                    image_token_ids = image_token_dataset['train'][image_index]['image_tokens']
                    image_token_ids = [image_token_ids[:image_token_ids.index(args.image_pad_token_id)]]

                image_position_indexes = np.nonzero(np.array(input_tokens) == args.image_start_token_id)[0].tolist()
                assert len(image_position_indexes) >= 1, f"{idxes}, {i}, {question}"
                if args.dataset_name in ["MMMU_VAL_MultiChoice", "MMMU_VAL_OpenEnded"]:
                    assert len(image_index_sequence) == len(image_position_indexes)
                else:
                    image_index_sequence = list(range(len(image_position_indexes)))
                for image_position_index, idx in zip(image_position_indexes[::-1], image_index_sequence[::-1]):
                    # print(idxes[i], idx, image_position_index, len(input_tokens), len(image_token_ids), image_index_sequence, question)
                    input_tokens = input_tokens[:image_position_index + 1] + image_token_ids[idx] + [args.image_end_token_id] + input_tokens[image_position_index + 1:]

            # MMLU
            if args.dataset_name in ["MMLU"]:
                new_input_tokens = current_few_shot_examples[examples['subject'][i]]
                if input_tokens[0] == 894:
                    input_tokens[0] = 16492
                if input_tokens[0] == 3529:
                    input_tokens[0] = 12148
                input_tokens = new_input_tokens + input_tokens

            examples['input_tokens'].append([tokenizer.bos_token_id] + input_tokens)

            if prompt_setting in ['zero_shot_cot', 'zero_shot_cot_choices_ppl']:
                if image == "": # for HallusionBench's question without image, and MMLU
                    examples['image_tokens'].append(examples['input_tokens'][-1])
                    examples['first_round_text'].append("")
                    examples['second_round_text'].append("")
                else:
                    examples['image_tokens'].append(image_token_ids[0])
                    # "# Detailed Analysis of Objects in the Image" ?
                    examples['first_round_text'].append(f"{image}Caption:")
                    examples['second_round_text'].append(f"\n{question}")
            
            elif prompt_setting in ['zero_shot_cod', 'zero_shot_cod_choices_ppl']:
                if image == "": # for HallusionBench's question without image
                    examples['image_tokens'].append(examples['input_tokens'][-1])
                    examples['first_round_text'].append("")
                    examples['second_round_text'].append("")
                    examples['third_round_text'].append("")
                    examples['fourth_round_text'].append("")
                    examples['fifth_round_text'].append("")
                else:
                    examples['image_tokens'].append(image_token_ids[0])
                    # "# Detailed Analysis of Objects in the Image" ?
                    examples['first_round_text'].append(f"# Detailed Analysis of Objects in the Image\n\n{image}Caption:")
                    examples['second_round_text'].append("\nObjects:\n-")
                    examples['third_round_text'].append("\nAttributes of objects:\n-")
                    examples['fourth_round_text'].append("\nRelationships between objects:\n-")
                    examples['fifth_round_text'].append(f"\n{question}")
                    # examples['second_round_text'].append("\nObjects and their descriptions:\n-")
                    # examples['third_round_text'].append("\nAttributes of objects and their descriptions:\n-")
                    # examples['fourth_round_text'].append("\nRelationships between objects and their descriptions:\n-")
        
        return examples

    combined_dataset = DatasetDict()
    if args.from_hf:
        dataset_split = dataset_name_split_mapping[args.dataset_name]
    else:
        dataset_split = 'train'
    
    # for datasets with question without image, we need to map the image index to row index
    if args.dataset_name in ["HallusionBench"]:
        sample_index2image_index = {}
        images_path = os.path.join(args.output_dir, args.dataset_name, 'images')
        images_files = os.listdir(images_path)
        images_files.sort()
        for i, image_file in enumerate(images_files):
            image_file = int(image_file.split('.')[0])
            sample_index2image_index[image_file] = i
    # for datasets with question with multiple images, we need to map the image index to row index list
    elif args.dataset_name in ["MMMU_VAL_MultiChoice", "MMMU_VAL_OpenEnded"]:
        sample_index2image_index_list = {}
        image_index = 0 
        for sample_index in range(text_dataset[dataset_split].num_rows):
            sample = text_dataset[dataset_split][sample_index]
            sample_index2image_index_list[sample_index] = []
            for i in range(1, 8):
                if sample[f"image_{i}"] is not None:
                    sample_index2image_index_list[sample_index].append(image_index)
                    image_index += 1
    
    for prompt_setting, num_template in zip(args.prompt_settings, args.num_templates):
        for template_index in range(num_template):
            if args.dataset_name in ["MMLU"]:
                current_few_shot_examples = prepare_few_shot_examples(prompt_setting, template_index)
            else:
                current_few_shot_examples = None
            # print(f"current_few_shot_examples: {current_few_shot_examples}")
            combined_dataset.cleanup_cache_files()
            combined_dataset[f"{prompt_setting}_{template_index}"] = text_dataset[dataset_split].map(
                lambda examples, idxes: prompt_and_tokenize_function(examples, idxes, prompt_setting, template_index, current_few_shot_examples),
                batched=True,
                with_indices=True,
                batch_size=args.process_batch_size,
                num_proc=args.process_num_workers if "VQAv2" in args.dataset_name else None,
                remove_columns=text_dataset[dataset_split].column_names,
                desc="Construct final dataset",
            )
    
    eval_path = os.path.join(args.output_dir, args.dataset_name, 'eval')
    os.makedirs(eval_path, exist_ok=True)
    combined_dataset.save_to_disk(eval_path, max_shard_size="20GB")
    print(combined_dataset)
    if dataset_type == 'multi-choice':
        for prompt_setting, num_template in zip(args.prompt_settings, args.num_templates):
            for template_index in range(num_template):
                choice_count_set = set(combined_dataset[f"{prompt_setting}_{template_index}"]['choice_count'])
                if len(choice_count_set) > 1:
                    print(f"{args.dataset_name}_{prompt_setting}_{template_index} has different choice count: {choice_count_set}")
                else:
                    print(f"{args.dataset_name}_{prompt_setting}_{template_index}  has the same choice count: {choice_count_set}")
    # statistics(combined_dataset, args.dataset_name, args.prompt_settings, args.num_templates)
    
# Script entry point. main() runs first; judging by the function body that
# ends just above (presumably main — confirm), it builds one dataset split per
# (prompt_setting, template_index) pair and saves the resulting DatasetDict to
# <output_dir>/<dataset_name>/eval. test() runs only after main() returns.
# NOTE(review): test() is defined elsewhere in this file (not visible here);
# confirm it is meant to run on every invocation and not only during development.
if __name__ == "__main__":
    main()
    test()
