import argparse
import glob
import json
import re

# The identical import is retried on ModuleNotFoundError on purpose: this can help resolve an
# import error of an mmf dependency that is not needed.
try:
    from mmf.utils.m4c_evaluators import TextVQAAccuracyEvaluator
except ModuleNotFoundError:
    from mmf.utils.m4c_evaluators import TextVQAAccuracyEvaluator


def merge_input_files(input_path):
    """Merge per-shard JSONL result files into one JSON file for the evaluator.

    Every file matching the glob pattern ``<input_path>-TextVQA-[0-9].*jsonl``
    is read in sorted path order. Each non-blank line is parsed as one JSON
    record, and all records are written as a single JSON list to
    ``<input_path>-TextVQA-merged.json``. If no file matches, an empty list
    is written.

    Args:
        input_path: Common path prefix of the per-shard result files.

    Returns:
        Path of the merged JSON file that was written.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    output_file_path = input_path + "-TextVQA-merged.json"

    # Glob (not regex) syntax: one digit, a literal ".", anything, then "jsonl".
    pattern = input_path + "-TextVQA-[0-9].*jsonl"
    # glob.glob() returns matches in arbitrary order; sort them so the merged
    # file is reproducible across runs and file systems.
    input_file_paths = sorted(glob.glob(pattern))

    results = []

    for input_file_path in input_file_paths:
        # JSON text is UTF-8; do not depend on the locale's default encoding.
        with open(input_file_path, "r", encoding="utf-8") as input_file:
            for line in input_file:
                # Tolerate blank lines (e.g. extra newlines at the end of a
                # shard) instead of failing in json.loads.
                if not line.strip():
                    continue
                results.append(json.loads(line))

    with open(output_file_path, "w") as output_file:
        json.dump(results, output_file)

    return output_file_path


# Note: This is based on https://github.com/haotian-liu/LLaVA/blob/c121f0432da27facab705978f83c4ada465e46fd/llava/eval/eval_textvqa.py#L17
# and slightly modified.
def prompt_processor(prompt):
    """Extract the lower-cased question from an evaluation prompt.

    Three prompt layouts are recognised, checked in this order:

    1. The prompt starts with ``"OCR tokens: "``: the question is the text
       between ``"Question: "`` and ``" Short answer:"``.
    2. The prompt contains ``"Reference OCR token: "`` and has exactly three
       lines: the question is the second line when the prompt starts with
       ``"Reference OCR token:"``, otherwise the first line.
    3. The prompt has exactly two lines: the question is the first line.

    Args:
        prompt: Full prompt text.

    Returns:
        The extracted question, lower-cased.

    Raises:
        RuntimeError: If the prompt matches none of the layouts above,
            including an ``"OCR tokens: "`` prompt that lacks the
            ``"Question: ... Short answer:"`` section.
    """
    # Split once; the layout checks below all look at the same line list.
    lines = prompt.split("\n")

    if prompt.startswith('OCR tokens: '):
        pattern = r"Question: (.*?) Short answer:"
        # DOTALL lets the captured question span several lines.
        match = re.search(pattern, prompt, re.DOTALL)
        if match is None:
            # Report a malformed prompt the same way as the fallback branch
            # rather than failing with an AttributeError on ``None``.
            raise RuntimeError("unexpected prompt format")
        question = match.group(1)
    elif "Reference OCR token: " in prompt and len(lines) == 3:
        if prompt.startswith("Reference OCR token:"):
            question = lines[1]
        else:
            question = lines[0]
    elif len(lines) == 2:
        question = lines[0]
    else:
        raise RuntimeError("unexpected prompt format")

    return question.lower()


# Note: This is based on https://github.com/haotian-liu/LLaVA/blob/c121f0432da27facab705978f83c4ada465e46fd/llava/eval/eval_textvqa.py#L35
# and slightly modified.
def evaluate(result_file_path, groundtruth_path):
    """Score merged predictions against the groundtruth and print the result.

    Prints the number of scored samples and the accuracy, computed as the
    value returned by ``TextVQAAccuracyEvaluator.eval_pred_list`` times 100.

    Args:
        result_file_path: JSON file holding a list of records; each record is
            read for its "sample_id" and "text" fields.
        groundtruth_path: JSON file whose "data" list entries are read for
            their "image_id" and "answers" fields.

    Raises:
        KeyError: If a record's "sample_id" has no groundtruth entry, or an
            expected field is missing.
    """
    with open(groundtruth_path) as gt_file:
        gt_entries = json.load(gt_file)["data"]

    # Index the reference answers by image id. When several entries share an
    # id, the last one wins.
    # NOTE(review): results are looked up by "sample_id" against this
    # "image_id" index - confirm the two identifiers are interchangeable.
    answers_by_id = {}
    for entry in gt_entries:
        image_id = entry["image_id"]
        answers_by_id[image_id] = entry["answers"]

    with open(result_file_path, "r") as merged_file:
        records = json.load(merged_file)

    def _to_prediction(record):
        # Pair the model output with the reference answers for its sample.
        answers = answers_by_id[record["sample_id"]]
        return {"pred_answer": record["text"], "gt_answers": answers}

    predictions = [_to_prediction(record) for record in records]

    scorer = TextVQAAccuracyEvaluator()
    accuracy_percent = 100.0 * scorer.eval_pred_list(predictions)
    report = 'Samples: {}\nAccuracy: {:.2f}%\n'.format(len(predictions), accuracy_percent)
    print(report)


if __name__ == "__main__":
    # Command-line entry point: merge the per-shard result files, then score
    # the merged file against the groundtruth.
    parser = argparse.ArgumentParser()
    # Both paths are mandatory. Marking them required makes argparse print a
    # usage error when one is omitted, instead of a TypeError/failed open()
    # on ``None`` further down.
    parser.add_argument('--input-path', type=str, required=True, help="Path to input file(s)")
    parser.add_argument(
        '--groundtruth-path', type=str, required=True, help="Path to groundtruth file"
    )
    args = parser.parse_args()

    result_file_path = merge_input_files(args.input_path)

    evaluate(result_file_path, args.groundtruth_path)
