import os
import json
import numpy as np
from m4c_evaluator import EvalAIAnswerProcessor, TextVQAAccuracyEvaluator
from vqa_metric import compute_vqa_accuracy
from ok_vqa_utils import OKVQAStemmer
from eval_score_coco import coco_caption_eval
from datasets import Dataset, load_from_disk
from eval_score_mmbench import mmbench_eval
import sys
import re
from collections import defaultdict

# Answer normalisers used by response_process() for VQA-type datasets:
# EvalAIAnswerProcessor (from m4c_evaluator) for general answer cleanup and
# OKVQAStemmer for OK-VQA. Instantiated once at import time and reused.
answer_processor = EvalAIAnswerProcessor()
okvqa_stemmer = OKVQAStemmer()

# Text markers (special tokens, pre-training template headers, SFT prompt
# phrases). Not referenced in this chunk; presumably passed to the generation
# backend as stop strings — confirm how the caller uses them. Entries must
# match the generated text exactly (note the typographic apostrophe in the
# last entry).
stop_tokens = [
    # Special tokens
    "<s>",
    "</s>",
    "Image:",
    "Caption:",
    "Detailed Caption:",
    "Detailed caption:",
    "Localized Narrative Caption:",
    "Localized narrative caption:",
    "Style:",
    # WIT Template tokens
    "Short description:",
    "Detailed description:",
    "# Introduction",
    "# Introduction to a Concept",
    "## Concepts",
    "## Concepts Related to",
    "### Introduction",
    "### Introduction to a Related Concept",
    # OI Template tokens
    "Region:",
    "Object:",
    "Objects:",
    "Objects and their descriptions:",
    "Attribute:",
    "Attributes:",
    "Attributes of objects:",
    "Attributes of objects and their descriptions:",
    "Relationships between objects:",
    "Relationships between objects and their descriptions:",
    "Location in the image:",
    "Location of the selected region in the image:",
    "Nearby Objects:",
    "# Overview of Objects in the Image",
    "## Detailed Analysis of Selected Object Regions",
    "### Detailed Analysis of a Selected Object Region",
    "# Detailed Analysis of Objects in the Image",
    "## Overview of Selected Object Regions",
    "### Overview of a Selected Object Region",
    # SFT Template tokens
    "<<SYS>>",
    "<</SYS>>",
    "[INST]",
    "[/INST]",
    "Here is the image:",
    "Here is the edited image:",
    "What's happening in the scene?",
    "Explain the visual content of the image in great detail.",
    "Analyze the image in a comprehensive and detailed manner.",
    "What are the key elements in this picture?",
    "What do you see happening in this image?",
    "What is this photo about?",
    "Write a detailed description of the given image.",
    "Describe the following image.",
    "Can you elaborate on the elements of the picture provided?",
    "What do you think is going on in this snapshot?",
    "Can you describe the main features of this image for me?",
    "Provide a one-sentence caption for the provided image.",
    "Create an image that visually represents the description:",
    "Answer the question using a single word or phrase.",
    "Answer with the option’s letter from the given choices directly.",
]

# Token-id counterpart of stop_tokens. It is None here; presumably the caller
# fills it in once a tokenizer is available (not visible in this chunk).
stop_tokens_ids = None

# Generation settings keyed by dataset_name. A few keys ("multi-choice-ppl",
# "VQA", "Caption", "CoD") are generic entries rather than dataset names.
# Presumably unpacked into vLLM sampling parameters — confirm at the call site.
# NOTE(review): `best_of` / `use_beam_search` exist only in older vLLM
# `SamplingParams` releases; confirm they are accepted by the pinned version.
# The trailing numeric comments on some `max_tokens` lines look like
# reference-length statistics (min, max, mean?) — TODO confirm.
vllm_configs = {
    "multi-choice-ppl": {
        "max_tokens": 1,
        "temperature": 0,
    },
    "VQA": {
        "max_tokens": 10,
        "best_of": 5,
        "use_beam_search": True,
        "temperature": 0,
        "length_penalty": -1,  
    },
    "Caption": {
        "max_tokens": 50,
        "best_of": 5,
        "use_beam_search": True,
        "temperature": 0,
        "repetition_penalty": 1,
        "length_penalty": -1,
    },
    "CoD": {
        "max_tokens": 300, # 100, 200
        "best_of": 5,
        "use_beam_search": True,
        "temperature": 0,
        "repetition_penalty": 1,
        "length_penalty": -1,
    },
    # "MMBench_DEV_EN": {
    #     "max_tokens": 5,
    # },
    # "MME":{
    #     "max_tokens": 5,
    # },
    # "SEEDBench_IMG": {

    # },
    # "ScienceQA_VAL": {

    # },
    # "ScienceQA_TEST": {

    # },
    "COCO": {
        "max_tokens": 30, # 8, 48, 10.6
        "best_of": 5,
        "use_beam_search": True,
        "temperature": 0,
        "repetition_penalty": 1,
        "length_penalty": -1,
    },
    "Flickr30K":{
        "max_tokens": 30, # 2, 68, 12.3
        "best_of": 5,
        "use_beam_search": True,
        "temperature": 0,
        "repetition_penalty": 1,
        "length_penalty": -1,
    },
    "NoCaps":{
        "max_tokens": 50, # 8, 154, 11.5
        "best_of": 5,
        "use_beam_search": True,
        "temperature": 0,
        "repetition_penalty": 1,
        "length_penalty": -1,
    },
    "WHOOPS-Caption":{
        "max_tokens": 60,
        "best_of": 5,
        "use_beam_search": True,
        "temperature": 0,
        "repetition_penalty": 1,
        "length_penalty": -1,
    },
    "VQAv2_VAL": {
        "max_tokens": 10,
        "best_of": 5,
        "use_beam_search": True,
        "temperature": 0,
        "length_penalty": -1,
    },
    "VQAv2_TEST": {
        "max_tokens": 10,
        "best_of": 5,
        "use_beam_search": True,
        "temperature": 0,
        "length_penalty": -1,
    },
    "OK-VQA": {
        "max_tokens": 10,
        "best_of": 5,
        "use_beam_search": True,
        "temperature": 0,
        "length_penalty": -1,
    },
    "VizWiz_VAL": {
        "max_tokens": 10,
        "best_of": 5,
        "use_beam_search": True,
        "temperature": 0,
        "length_penalty": -1,
    },
    "VizWiz_TEST": {
        "max_tokens": 10,
        "best_of": 5,
        "use_beam_search": True,
        "temperature": 0,
        "length_penalty": -1,
    },
    "TextVQA": {
        "max_tokens": 10,
        "best_of": 5,
        "use_beam_search": True,
        "temperature": 0,
        "length_penalty": -1,  
    },
    "GQA_TESTDEV_BALANCED": {
        "max_tokens": 10,
        "best_of": 5,
        "use_beam_search": True,
        "temperature": 0,
        "length_penalty": -1,  
    },
    "MMMU_VAL_OpenEnded": {
        "max_tokens": 10,
        "best_of": 5,
        "use_beam_search": True,
        "temperature": 0,
        "length_penalty": -1,
    },
    "MathVista_OpenEnded": {
        "max_tokens": 10,
        "best_of": 5,
        "use_beam_search": True,
        "temperature": 0,
        "length_penalty": -1,
    },
}

# Candidate numbers of answer options for each multiple-choice dataset. Not
# referenced in this chunk; presumably consumed by the multi-choice evaluation.
# NOTE(review): the MMMU_VAL_MultiChoice list skips 8 — confirm this matches
# the data.
choice_count_dict = {
    "MMBench_DEV_EN": [2, 3, 4],
    "MMBench_TEST_EN": [2, 3, 4],
    "ScienceQA_VAL": [2, 3, 4, 5],
    "ScienceQA_TEST": [2, 3, 4, 5],
    "SEEDBench_IMG": [4],
    "MathVista_MultiChoice": [2, 3, 4, 5, 6, 7, 8],
    "MMMU_VAL_MultiChoice": [2, 3, 4, 5, 6, 7, 9],
    "MMLU": [4],
    "PIQA_VAL": [2],
}

def yes_or_no_extraction(output):
    """Map a free-form model answer to ``'Yes'``, ``'No'`` or ``'Unknown'``.

    The answer is lower-cased and split into alphabetic words. The result is
    ``'Yes'`` when the word "yes" occurs without "no", ``'No'`` in the
    opposite case, and ``'Unknown'`` when both or neither occur.

    Whole words are matched rather than substrings. Otherwise words such as
    "not", "know", "snow", "cannot" (contain "no") or "eyes" (contains "yes")
    would be misread as a yes/no answer.
    """
    words = set(re.findall(r"[a-z]+", output.lower()))
    has_yes = "yes" in words
    has_no = "no" in words
    if has_yes and not has_no:
        return 'Yes'
    if has_no and not has_yes:
        return 'No'
    return 'Unknown'

# Unit words removed, in this order, from the end of MathVista open-ended answers.
_MATHVISTA_UNIT_SUFFIXES = (" g", " years", " ml")


def _remove_suffixes(text, suffixes):
    """Trim whitespace, then drop each of *suffixes* (in order) that *text* ends with."""
    text = text.strip()
    for suffix in suffixes:
        if text.endswith(suffix):
            text = text[: -len(suffix)].strip()
    return text


def response_process(examples, idxes, args):
    """Normalise a batch of raw model responses (batched ``Dataset.map`` callback).

    Every response is cut at its first newline. Further cleanup depends on
    ``args.dataset_type`` / ``args.dataset_name``:

    * ``"Y/N"``: reduced to ``'Yes'`` / ``'No'`` / ``'Unknown'``.
    * ``"VQA"``: quote/period trimming (MMMU, MathVista), unit-suffix removal
      (MathVista), or "text before the first ', '" plus ``answer_processor``
      (all others). Then dataset-specific fixes for OK-VQA, VizWiz and GQA.

    Args:
        examples: batch dict holding a ``"response"`` list of strings.
        idxes: row indices of the batch.
        args: run configuration with ``dataset_type`` and ``dataset_name``.

    Returns:
        *examples* with ``"input_index"`` (= *idxes*) and ``"response_post"`` set.
    """
    examples["input_index"] = idxes

    # Keep only the first line of each response; later lines are dropped.
    responses = [response.split("\n", 1)[0] for response in examples["response"]]

    if args.dataset_type == "Y/N":
        responses = [yes_or_no_extraction(response) for response in responses]

    if args.dataset_type == "VQA":
        name = args.dataset_name
        if name in ["MMMU_VAL_OpenEnded"]:
            responses = [response.strip("\"").strip(".") for response in responses]
        elif name in ["MathVista_OpenEnded"]:
            # str.strip() removes a *set of characters*, not a suffix: chaining
            # .strip(" years") would also eat 's'/'e'/'a'/'r'/'y' from genuine
            # answers ("square" -> "qu"). Remove unit words as whole suffixes.
            responses = [
                _remove_suffixes(response.strip("\"").strip("."), _MATHVISTA_UNIT_SUFFIXES)
                for response in responses
            ]
        else:
            # Keep the text before the first ", " and normalise it.
            responses = [answer_processor(response.split(", ", 1)[0]) for response in responses]

        if name in ["OK-VQA"]:
            responses = [okvqa_stemmer.stem(response) for response in responses]
        if name in ["VizWiz_VAL", "VizWiz_TEST"]:
            responses = [response.replace("un answerable", "unanswerable").replace("Un answerable", "unanswerable") for response in responses]
        if name in ["GQA_TESTDEV_BALANCED"]:
            responses = [response.rstrip('.').lower() for response in responses]

    examples["response_post"] = responses
    return examples

def post_process_response(args, results):
    """Build a ``Dataset`` from raw generation records and normalise their responses.

    Args:
        args: run configuration. It supplies ``process_batch_size`` and
            ``process_num_workers`` for the map and is forwarded to
            ``response_process``.
        results: list of per-example dicts; ``response_process`` reads their
            ``"response"`` field.

    Returns:
        The mapped ``Dataset``, with ``"input_index"`` and ``"response_post"``
        columns added.
    """
    def _process_batch(batch, batch_indices):
        return response_process(batch, batch_indices, args)

    raw_dataset = Dataset.from_list(results)
    processed = raw_dataset.map(
        _process_batch,
        batched=True,
        with_indices=True,
        batch_size=args.process_batch_size,
        num_proc=args.process_num_workers,
        desc="Post process responses",
    )
    return processed

def eval_POPE(args, result_dataset):
    """Compute POPE yes/no metrics per category and overall.

    Follows https://github.com/Luodian/Otter/blob/main/pipeline/benchmarks/datasets/pope.py

    Args:
        args: unused; kept for a uniform evaluator signature.
        result_dataset: iterable of rows. Each row has ``"response_post"``
            (``"Yes"`` or ``"No"``), ``"answer"`` (compared case-insensitively
            with "yes"/"no") and ``"category"`` (``"adversarial"``,
            ``"popular"`` or ``"random"``).

    Returns:
        dict mapping each category and ``"overall"`` to its raw counts (TP, TN,
        FP, FN, yes_count, no_count) plus precision, recall, f1, acc and
        yes_ratio. A ratio whose denominator is zero (e.g. a category with
        no rows) is reported as 0.

    Raises:
        ValueError: if a prediction is neither ``"Yes"`` nor ``"No"``.
        KeyError: if a row's category is not one of the three above.
    """
    def _ratio(num, den):
        return num / den if den else 0.0

    metrics = {
        name: {"TP": 0, "TN": 0, "FP": 0, "FN": 0, "yes_count": 0, "no_count": 0}
        for name in ("adversarial", "popular", "random", "overall")
    }

    for row in result_dataset:
        pred = row["response_post"]
        if pred == "Yes":
            pred, count_key = "yes", "yes_count"
        elif pred == "No":
            pred, count_key = "no", "no_count"
        else:
            raise ValueError(f"Invalid response: {pred}")

        # Normalise the gold label so it compares against the lower-cased prediction.
        answer = str(row["answer"]).strip().lower()
        if pred == "yes":
            outcome = "TP" if answer == "yes" else "FP"
        else:
            outcome = "TN" if answer == "no" else "FN"

        for key in (row["category"], "overall"):
            metrics[key][count_key] += 1
            metrics[key][outcome] += 1

    # "overall" is derived with the same formulas as each category.
    for m in metrics.values():
        tp, tn, fp, fn = m["TP"], m["TN"], m["FP"], m["FN"]
        precision = _ratio(tp, tp + fp)
        recall = _ratio(tp, tp + fn)
        m["precision"] = precision
        m["recall"] = recall
        m["f1"] = _ratio(2 * precision * recall, precision + recall)
        m["acc"] = _ratio(tp + tn, tp + tn + fp + fn)
        m["yes_ratio"] = _ratio(m["yes_count"], m["yes_count"] + m["no_count"])
    return metrics

def eval_MME(args, result_dataset):
    """Compute MME scores per task, per evaluation type and in total.

    Within each task, rows are taken in consecutive pairs (presumably the two
    questions asked about one image — TODO confirm the dataset ordering).

    Per task:
        ``acc``: fraction of rows whose prediction equals the answer.
        ``acc_plus``: fraction of pairs with both rows correct.
        ``score``: ``100 * acc + 100 * acc_plus``.
        ``unknown``: fraction of ``"Unknown"`` predictions.

    Evaluation-type and total scores are sums of task scores.

    Args:
        args: unused; kept for a uniform evaluator signature.
        result_dataset: table whose ``"response_post"``, ``"answer"`` and
            ``"category"`` columns can be read with ``result_dataset[name]``.

    Returns:
        dict with ``"eval_type_scores"``, ``"total_scores"``,
        ``"total_unknown"`` and ``"task_scores"``. Tasks without rows report
        zeros instead of raising ``ZeroDivisionError``.
    """
    eval_type_dict = {
        "Perception": ["existence", "count", "position", "color", "posters", "celebrity", "scene", "landmark", "artwork", "OCR"],
        "Cognition": ["commonsense_reasoning", "numerical_calculation", "text_translation", "code_reasoning"]
    }

    def _ratio(num, den):
        return num / den if den else 0.0

    # Read each column once and group rows by task, keeping dataset order.
    # This replaces one filter pass per task plus per-row indexing.
    preds = list(result_dataset["response_post"])
    answers = list(result_dataset["answer"])
    categories = list(result_dataset["category"])
    rows_by_task = defaultdict(list)
    for pred, answer, category in zip(preds, answers, categories):
        rows_by_task[category].append((pred, answer))

    task_score_dict = {}
    for eval_type, task_name_list in eval_type_dict.items():
        task_score_dict[eval_type] = {}
        for task_name in task_name_list:
            rows = rows_by_task.get(task_name, [])
            num_rows = len(rows)
            correct = [pred == answer for pred, answer in rows]
            # A pair counts only if both of its rows are correct; a trailing
            # unpaired row (odd count) is ignored rather than raising IndexError.
            pairs_correct = sum(correct[i] and correct[i + 1] for i in range(0, num_rows - 1, 2))
            acc = _ratio(sum(correct), num_rows)
            acc_plus = _ratio(pairs_correct, num_rows // 2)
            task_score_dict[eval_type][task_name] = {
                "acc": acc,
                "acc_plus": acc_plus,
                "score": acc * 100 + acc_plus * 100,
                "unknown": _ratio(sum(pred == "Unknown" for pred, _ in rows), num_rows),
                "count": num_rows,
            }

    scores = {"eval_type_scores": {}}
    for eval_type, task_name_list in eval_type_dict.items():
        task_scores = [task_score_dict[eval_type][task_name] for task_name in task_name_list]
        count = sum(task["count"] for task in task_scores)
        scores["eval_type_scores"][eval_type] = {
            "scores": sum(task["score"] for task in task_scores),
            "count": count,
            "unknown": _ratio(sum(task["unknown"] * task["count"] for task in task_scores), count),
        }

    type_scores = scores["eval_type_scores"].values()
    scores["total_scores"] = sum(entry["scores"] for entry in type_scores)
    scores["total_unknown"] = _ratio(sum(entry["unknown"] * entry["count"] for entry in type_scores), len(categories))
    scores["task_scores"] = task_score_dict
    return scores

def eval_Winoground_YN(args, result_dataset):
    """Score Winoground posed as yes/no questions.

    Rows come in groups of four, in the order c0_i0, c1_i0, c0_i1, c1_i1
    (presumably caption/image index pairs). A case-insensitive "yes" scores
    1.0 and anything else 0.0. Text, image and group criteria follow
    https://twitter.com/ChengleiSi/status/1731047075528561119

    Args:
        args: unused; kept for a uniform evaluator signature.
        result_dataset: dataset exposing ``num_rows`` and integer row indexing;
            every row carries ``"response_post"``.

    Returns:
        dict with the per-criterion correct counts, the number of groups
        (``"total"``) and the three scores as percentages.
    """
    slot_order = ("c0_i0", "c1_i0", "c0_i1", "c1_i1")

    def matches_text(s):
        return s["c0_i0"] > s["c1_i0"] and s["c1_i1"] > s["c0_i1"]

    def matches_image(s):
        return s["c0_i0"] > s["c0_i1"] and s["c1_i1"] > s["c1_i0"]

    hits = {"text": 0, "image": 0, "group": 0}
    n_groups = 0
    for start in range(0, result_dataset.num_rows, 4):
        slot_scores = {}
        for offset, slot in enumerate(slot_order):
            said_yes = result_dataset[start + offset]["response_post"].lower() == 'yes'
            slot_scores[slot] = 1.0 if said_yes else 0.0
        text_ok = matches_text(slot_scores)
        image_ok = matches_image(slot_scores)
        hits["text"] += int(text_ok)
        hits["image"] += int(image_ok)
        hits["group"] += int(image_ok and text_ok)
        n_groups += 1

    return {
        "text_correct_count": hits["text"],
        "image_correct_count": hits["image"],
        "group_correct_count": hits["group"],
        "total": n_groups,
        "text score": hits["text"] / n_groups * 100,
        "image score": hits["image"] / n_groups * 100,
        "group score": hits["group"] / n_groups * 100,
    }

def eval_HallusionBench(args, result_dataset):
    """Compute HallusionBench aAcc / fAcc / qAcc, overall and per category.

    Each row's ``"index"`` is split on ``'_'``: fields 3, 4 and 5 are taken as
    set_id, figure_id and question_id. A row is correct when
    ``response_post == answer``.

    Metrics (all percentages):
        ``aAcc``: per-row accuracy.
        ``fAcc``: fraction of (l2-category, set_id, figure_id) groups whose
            rows are all correct.
        ``qAcc``: the same over (l2-category, set_id, question_id) groups.

    Args:
        args: unused; kept for a uniform evaluator signature.
        result_dataset: table with ``"category"``, ``"l2-category"``,
            ``"index"``, ``"response_post"`` and ``"answer"`` columns readable
            via ``result_dataset[name]``.

    Returns:
        dict keyed by ``"Overall"`` and then every category and l2-category
        value in sorted order (deterministic across runs), each mapping to
        ``{"aAcc", "fAcc", "qAcc"}``.
    """
    categories = list(result_dataset["category"])
    l2_categories = list(result_dataset["l2-category"])
    indexes = list(result_dataset["index"])
    preds = list(result_dataset["response_post"])
    answers = list(result_dataset["answer"])

    records = []
    for category, l2_category, index, pred, answer in zip(categories, l2_categories, indexes, preds, answers):
        parts = index.split('_')
        records.append({
            "category": category,
            "l2-category": l2_category,
            "set_id": parts[3],
            "figure_id": parts[4],
            "question_id": parts[5],
            "score": pred == answer,
        })

    def _grouped_acc(rows, id_field):
        # A group is correct only if every row in it is correct.
        groups = defaultdict(list)
        for row in rows:
            groups[f"{row['l2-category']}_{row['set_id']}_{row[id_field]}"].append(row["score"])
        return np.mean([np.all(group) for group in groups.values()]) * 100

    def _summarise(rows):
        return {
            "aAcc": np.mean([row["score"] for row in rows]) * 100,
            "fAcc": _grouped_acc(rows, "figure_id"),
            "qAcc": _grouped_acc(rows, "question_id"),
        }

    scores = {"Overall": _summarise(records)}
    for key in ("category", "l2-category"):
        # sorted() instead of set iteration so the JSON key order is reproducible.
        for value in sorted({row[key] for row in records}, key=str):
            scores[value] = _summarise([row for row in records if row[key] == value])
    return scores

def eval_WHOOPS_Caption(args, results_json):
    """Compute BLEU-4 and CIDEr for WHOOPS! caption predictions.

    Follows https://colab.research.google.com/drive/1av7JdDk005qQL6WdAVL0kFlah7VXV0Md?usp=sharing

    Args:
        args: unused; kept for a uniform evaluator signature.
        results_json: list of dicts with ``"image_id"``, ``"caption"`` (the
            prediction) and ``"gold_answer"`` (list of reference captions).

    Returns:
        ``{"Bleu_4": float, "CIDEr": float}``.
    """
    # Imported inside the function, so pycocoevalcap is only needed when this
    # evaluator is actually called.
    from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
    from pycocoevalcap.cider.cider import Cider
    from pycocoevalcap.bleu.bleu import Bleu

    predictions, references = {}, {}
    for result in results_json:
        image_id = result["image_id"]
        # One prediction per image; a repeated image_id keeps its last entry.
        predictions[image_id] = [{"image_id": image_id, "caption": result["caption"]}]
        references[image_id] = [{"image_id": image_id, "caption": gold.strip()} for gold in result["gold_answer"]]

    tokenizer = PTBTokenizer()
    predictions = tokenizer.tokenize(predictions)
    references = tokenizer.tokenize(references)

    scores = {}
    for scorer, method in ((Bleu(4), "Bleu_4"), (Cider(), "CIDEr")):
        print(f'computing {scorer.method()} score...')
        score, _ = scorer.compute_score(references, predictions)
        # Bleu(4) yields a list of BLEU-1..4 scores; index 3 is BLEU-4.
        scores[method] = score[3] if isinstance(score, list) else score
    return scores

def _write_eval_scores(score_path, scores):
    """Write *scores* to *score_path* as indented JSON and report the path."""
    # Plain UTF-8 without a BOM: RFC 8259 forbids a leading BOM, and json.load()
    # on a file opened with encoding="utf-8" fails on one.
    with open(score_path, "w", encoding="utf-8") as f:
        json.dump(scores, f, indent=4, ensure_ascii=False)
    print(f"Saved scores to {score_path}")


def _print_eval_summary(summary):
    """Print *summary* between two rules of '#' characters."""
    print("########################################")
    print(summary)
    print("########################################")


def _check_row_count(result_dataset, origin_dataset):
    """Raise ValueError unless predictions and references have the same row count."""
    if result_dataset.num_rows != origin_dataset.num_rows:
        raise ValueError(
            f"Result rows ({result_dataset.num_rows}) do not match "
            f"reference rows ({origin_dataset.num_rows})"
        )


def _eval_score_yn(args, result_dataset, origin_path, dataset_split, result_path, score_path, answer_mapping):
    """Score a yes/no benchmark, then save the rows as CSV and the scores as JSON."""
    name = args.dataset_name
    origin_dataset = load_from_disk(origin_path)[dataset_split]
    _check_row_count(result_dataset, origin_dataset)
    result_dataset = result_dataset.add_column("answer", origin_dataset[answer_mapping[name]])
    if name not in ["Winoground-YN"]:
        result_dataset = result_dataset.add_column('category', origin_dataset['category'])
    header = f"{name}\t{args.prompt_setting}_{args.template_index}"

    if name == "MME":
        if result_dataset.num_rows % 2 != 0:
            raise ValueError("MME expects an even number of rows")
        scores = eval_MME(args, result_dataset)
        summary = f"{header}\nPerception\tCognition\tTotal\n{scores['eval_type_scores']['Perception']['scores']:.2f}\t{scores['eval_type_scores']['Cognition']['scores']:.2f}\t{scores['total_scores']:.2f}"
    elif name == "POPE":
        scores = eval_POPE(args, result_dataset)
        summary = f"{header}\nAdversarial\tPopular\tRandom\tOverall\tYes Ratio\n{scores['adversarial']['f1'] * 100:.2f}\t{scores['popular']['f1'] * 100:.2f}\t{scores['random']['f1'] * 100:.2f}\t{scores['overall']['f1'] * 100:.2f}\t{scores['overall']['yes_ratio'] * 100:.2f}"
    elif name == "Winoground-YN":
        if result_dataset.num_rows % 4 != 0:
            raise ValueError("Winoground-YN expects a multiple of 4 rows")
        scores = eval_Winoground_YN(args, result_dataset)
        summary = f"{header}\nText\tImage\tGroup\n{scores['text score']:.2f}\t{scores['image score']:.2f}\t{scores['group score']:.2f}"
    elif name == "HallusionBench":
        result_dataset = result_dataset.add_column("l2-category", origin_dataset["l2-category"])
        result_dataset = result_dataset.add_column("index", origin_dataset["index"])
        scores = eval_HallusionBench(args, result_dataset)
        overall = scores['Overall']
        summary = f"{header}\nOverall aAcc\tfAcc\tqAcc\tAvg\n{overall['aAcc']:.2f}\t{overall['fAcc']:.2f}\t{overall['qAcc']:.2f}\t{(overall['aAcc'] + overall['fAcc'] + overall['qAcc']) / 3:.2f}"
    else:
        # Fail before writing anything: there is no scorer for this dataset.
        raise NotImplementedError(f"No yes/no evaluator for dataset {name!r}")

    result_dataset.to_csv(result_path)
    print(f"Saved results to {result_path}")
    _write_eval_scores(score_path, scores)
    _print_eval_summary(summary)


def _eval_score_caption(args, result_dataset, origin_path, dataset_split, result_path, score_path, answer_mapping, image_id_mapping):
    """Attach image ids and references to captions, save them as JSON and score known caption benchmarks."""
    name = args.dataset_name
    origin_dataset = load_from_disk(origin_path)[dataset_split]
    _check_row_count(result_dataset, origin_dataset)
    result_dataset = result_dataset.add_column("image_id", origin_dataset[image_id_mapping[name]])
    if name in ["Flickr30K"]:
        result_dataset = result_dataset.map(lambda x: {"image_id": f"flickr30k-images/{x['image_id']}"})
    result_dataset = result_dataset.add_column("gold_answer", origin_dataset[answer_mapping[name]])
    result_dataset = result_dataset.rename_column("response_post", "caption")
    result_dataset = result_dataset.remove_columns("input_index")

    result_path = f"{result_path[:-4]}.json"
    results_json = [
        {
            "image_id": row["image_id"],
            "caption": row["caption"],
            "response": row["response"],
            "gold_answer": row["gold_answer"],
        }
        for row in result_dataset
    ]
    with open(result_path, "w") as f:
        json.dump(results_json, f, indent=4)
    print(f"Saved results to {result_path}")

    if name not in ["COCO", "Flickr30K", "NoCaps", "WHOOPS-Caption"]:
        return
    if name in ["WHOOPS-Caption"]:
        scores = eval_WHOOPS_Caption(args, results_json)
    else:
        scores = coco_caption_eval(os.path.join(args.result_dir, name, "annotations"), result_path, dataset_split, dataset_name=name)
    _write_eval_scores(score_path, scores)
    _print_eval_summary(f"{name}_{args.prompt_setting}_{args.template_index}\nCider\tBleu_4\n{scores['CIDEr'] * 100.0:.2f}\t{scores['Bleu_4'] * 100.0:.2f}")


def _eval_score_vqa(args, result_dataset, origin_path, dataset_split, result_path, score_path, answer_mapping):
    """Attach questions/ids/references to VQA answers, save them as JSON and score known VQA benchmarks."""
    name = args.dataset_name
    origin_dataset = load_from_disk(origin_path)[dataset_split]
    if args.debug:
        # Debug runs are scored against at most the first 20000 examples.
        # Clamp so a smaller split does not make select() raise IndexError.
        origin_dataset = origin_dataset.select(range(min(20000, origin_dataset.num_rows)))
    _check_row_count(result_dataset, origin_dataset)
    result_dataset = result_dataset.add_column("question", origin_dataset["question"])

    if name in ["VizWiz_VAL", "VizWiz_TEST"]:
        result_dataset = result_dataset.add_column("filename", origin_dataset["filename"])

    # The question-id column is named differently across source datasets.
    if name in ["VizWiz_VAL", "VizWiz_TEST", "GQA_TESTDEV_BALANCED", "MMMU_VAL_OpenEnded", "MMMU_TEST_OpenEnded"]:
        question_id_column = "id"
    elif name in ["MathVista_OpenEnded"]:
        question_id_column = "pid"
    else:
        question_id_column = "question_id"
    result_dataset = result_dataset.add_column("question_id", origin_dataset[question_id_column])

    # VQAv2_TEST and VizWiz_TEST are handled without gold answers.
    has_gold = name not in ["VQAv2_TEST", "VizWiz_TEST"]
    if has_gold:
        result_dataset = result_dataset.add_column("gold_answer", origin_dataset[answer_mapping[name]])

    result_dataset = result_dataset.rename_column("response_post", "answer")
    result_dataset = result_dataset.remove_columns("input_index")
    result_path = f"{result_path[:-4]}.json"

    results_json = []
    for row in result_dataset:
        entry = {
            "question_id": row["question_id"],
            "question": row["question"],
            "response": row["response"],
            "answer": row["answer"],
        }
        if has_gold:
            entry["gold_answer"] = row["gold_answer"]
        if name in ["VizWiz_TEST"]:
            entry["image"] = row["filename"]
        results_json.append(entry)

    if name == "VQAv2_TEST":
        # Add an empty answer for every test question missing from the
        # predictions (presumably the test server expects all of them).
        # Question list source:
        # https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_test.json
        with open(os.path.join(args.dataset_dir, 'results', name, 'annotations', "vqa_test.json"), "r") as f:
            vqa_test = json.load(f)
        vqa_test_question_ids = {q["question_id"] for q in vqa_test}
        results_json_question_ids = {q["question_id"] for q in results_json}
        for question_id in list(vqa_test_question_ids - results_json_question_ids):
            results_json.append({
                "question_id": question_id,
                "question": "",
                "response": "",
                "answer": "",
            })

    with open(result_path, "w") as f:
        json.dump(results_json, f, indent=4)
    print(f"Saved results to {result_path}")

    if name not in ["VQAv2_VAL", "OK-VQA", "TextVQA", "VizWiz_VAL", "GQA_TESTDEV_BALANCED", "MMMU_VAL_OpenEnded", "MathVista_OpenEnded"]:
        return
    if name in ["GQA_TESTDEV_BALANCED", "MMMU_VAL_OpenEnded", "MathVista_OpenEnded"]:
        # Exact string match between the processed answer and gold_answer.
        correct = [1 if answer == gold_answer else 0 for answer, gold_answer in zip(result_dataset["answer"], result_dataset["gold_answer"])]
        total = len(correct)
        scores = {"overall": sum(correct) / total * 100 if total else 0.0, "correct": sum(correct), "total": total}
    else:
        scores = compute_vqa_accuracy(os.path.join(args.result_dir, name, "annotations"), result_path, dataset_split, dataset_name=name)
    _write_eval_scores(score_path, scores)
    summary = f"{name}_{args.prompt_setting}_{args.template_index}\nOverall\n{scores['overall']:.2f}"
    if name in ["MMMU_VAL_OpenEnded", "MathVista_OpenEnded"]:
        summary += f"\n\nCorrect\tTotal\n{scores['correct']}\t{scores['total']}"
    _print_eval_summary(summary)


def eval_score(args, results):
    """Post-process raw generations, attach references, save results and compute scores.

    Dispatches on ``args.dataset_type``:
        ``"Y/N"``: MME / POPE / Winoground-YN / HallusionBench; rows are
            saved as CSV.
        ``"multi-choice"``: delegated to ``mmbench_eval``.
        ``"Caption"``: captions are saved as JSON; COCO, Flickr30K, NoCaps
            and WHOOPS-Caption are scored.
        ``"VQA"``: answers are saved as JSON; datasets with gold answers
            are scored.

    Args:
        args: run configuration (dataset/result directories, dataset name and
            type, prompt setting, template index, checkpoint path, flags).
        results: list of raw generation records passed to
            ``post_process_response``.

    Raises:
        NotImplementedError: for an unsupported dataset type, or a yes/no
            dataset that has no scorer.
        ValueError: when the predictions and the reference split differ in row
            count, or MME / Winoground-YN row counts are not a multiple of 2 / 4.
    """
    from Evaluation import dataset_name_split_mapping, dataset_name_answer_mapping, dataset_name_image_id_mapping

    if args.from_hf:
        origin_path = os.path.join(args.dataset_dir, args.dataset_name)
        dataset_split = dataset_name_split_mapping[args.dataset_name]
    else:
        origin_path = os.path.join(args.dataset_dir, args.dataset_name, 'datasets')
        dataset_split = 'train'

    result_dataset = post_process_response(args, results)
    os.makedirs(os.path.join(args.result_dir, f"{args.dataset_name}"), exist_ok=True)
    result_path = os.path.join(args.result_dir, f"{args.dataset_name}/{args.prompt_setting}_{args.template_index}_{'_'.join(args.checkpoint_path.split('/')[-2:])}.csv")
    score_path = f"{result_path[:-4]}_scores.json"

    if args.dataset_type == "Y/N":
        _eval_score_yn(args, result_dataset, origin_path, dataset_split, result_path, score_path,
                       dataset_name_answer_mapping)
    elif args.dataset_type == "multi-choice":
        mmbench_eval(args, origin_path, result_path, score_path, dataset_split, result_dataset=result_dataset)
    elif args.dataset_type == "Caption":
        _eval_score_caption(args, result_dataset, origin_path, dataset_split, result_path, score_path,
                            dataset_name_answer_mapping, dataset_name_image_id_mapping)
    elif args.dataset_type == "VQA":
        _eval_score_vqa(args, result_dataset, origin_path, dataset_split, result_path, score_path,
                        dataset_name_answer_mapping)
    else:
        raise NotImplementedError