import os
import argparse
import json
import re

from llava.eval.m4c_evaluator import TextVQAAccuracyEvaluator


def get_args():
    """Parse the command-line options of the TextVQA evaluation script.

    Returns:
        An ``argparse.Namespace`` with string attributes ``annotation_file``,
        ``result_file`` and ``result_dir``; each is ``None`` when the
        corresponding flag is not given.
    """
    parser = argparse.ArgumentParser()
    # All three options are optional string flags with no default.
    for flag in ("--annotation-file", "--result-file", "--result-dir"):
        parser.add_argument(flag, type=str)
    return parser.parse_args()


def prompt_processor(prompt):
    """Extract the lower-cased question text from a result prompt.

    Recognised prompt layouts:

    * Starts with ``"OCR tokens: "``: the question is the text between
      ``"Question: "`` and ``" Short answer:"`` (it may span lines).
    * Contains ``"Reference OCR token: "`` and has exactly three lines: the
      question is the second line if the prompt starts with the reference
      line, otherwise the first line.
    * Has exactly two lines: the question is the first line.

    Args:
        prompt: The ``"prompt"`` string stored in a result record.

    Returns:
        The extracted question, lower-cased.

    Raises:
        ValueError: If the prompt matches none of the layouts above, or an
            ``"OCR tokens: "`` prompt lacks the question section.
    """
    lines = prompt.split("\n")
    if prompt.startswith("OCR tokens: "):
        # re.DOTALL lets the captured question span multiple lines.
        match = re.search(r"Question: (.*?) Short answer:", prompt, re.DOTALL)
        if match is None:
            raise ValueError(
                "OCR-token prompt has no 'Question: ... Short answer:' "
                f"section: {prompt!r}"
            )
        question = match.group(1)
    elif "Reference OCR token: " in prompt and len(lines) == 3:
        # The reference-token line may come before or after the question.
        if prompt.startswith("Reference OCR token:"):
            question = lines[1]
        else:
            question = lines[0]
    elif len(lines) == 2:
        question = lines[0]
    else:
        # Explicit raise: an `assert` is stripped under `python -O`, which
        # would leave `question` unbound and fail with UnboundLocalError.
        raise ValueError(f"Unrecognised prompt format: {prompt!r}")

    return question.lower()


def eval_single(annotation_file, result_file):
    """Score one JSONL result file against the annotation file and print it.

    Prints the experiment name (the result file's base name without its
    extension), followed by the sample count and the accuracy computed by
    ``TextVQAAccuracyEvaluator.eval_pred_list``, shown as a percentage.

    Args:
        annotation_file: Path to a JSON file whose top-level ``"data"`` list
            holds records with ``"image_id"``, ``"question"`` and
            ``"answers"`` keys.
        result_file: Path to a JSONL file. Each non-blank line is a record with
            ``"question_id"``, ``"prompt"`` and ``"text"`` keys.

    Raises:
        KeyError: If a result has no matching annotation.
        ValueError: If a result's prompt has an unrecognised layout (raised
            by ``prompt_processor``).
    """
    experiment_name = os.path.splitext(os.path.basename(result_file))[0]
    print(experiment_name)

    # Use context managers so both files are closed; the original left them
    # to the garbage collector.
    with open(annotation_file, encoding="utf-8") as f:
        annotation_records = json.load(f)["data"]
    # Index by (image_id, lower-cased question). prompt_processor lower-cases
    # the question it extracts, so both sides of the lookup agree.
    annotations = {
        (annotation["image_id"], annotation["question"].lower()): annotation
        for annotation in annotation_records
    }

    with open(result_file, encoding="utf-8") as f:
        # Skip blank lines (e.g. a stray empty line) that json.loads rejects.
        results = [json.loads(line) for line in f if line.strip()]

    pred_list = []
    for result in results:
        # NOTE(review): a result's "question_id" is matched against the
        # annotation's "image_id". Presumably the generation script stores
        # the image id there; confirm.
        key = (result["question_id"], prompt_processor(result["prompt"]))
        try:
            annotation = annotations[key]
        except KeyError:
            raise KeyError(
                f"No annotation for {key!r} in {annotation_file}"
            ) from None
        pred_list.append(
            {
                "pred_answer": result["text"],
                "gt_answers": annotation["answers"],
            }
        )

    evaluator = TextVQAAccuracyEvaluator()
    print(
        "Samples: {}\nAccuracy: {:.2f}%\n".format(
            len(pred_list), 100.0 * evaluator.eval_pred_list(pred_list)
        )
    )


if __name__ == "__main__":
    args = get_args()

    # A single, explicitly named result file.
    if args.result_file is not None:
        eval_single(args.annotation_file, args.result_file)

    # Every .jsonl file in a result directory, in sorted filename order;
    # other entries are reported and ignored.
    if args.result_dir is not None:
        for entry in sorted(os.listdir(args.result_dir)):
            if entry.endswith(".jsonl"):
                eval_single(
                    args.annotation_file, os.path.join(args.result_dir, entry)
                )
            else:
                print(f"Skipping {entry}")
