import datetime
import json
import os
import re
import statistics

from loguru import logger as eval_logger

import lmms_eval.tasks._task_utils.file_utils as file_utils
from lmms_eval.tasks._task_utils.vqa_eval_metric import EvalAIAnswerProcessor


def vqav2_doc_to_visual(doc):
    """Return the document's image as a one-element list after converting it to RGB.

    Args:
        doc: Dataset record; ``doc["image"]`` must provide a ``convert(mode)``
            method (presumably a PIL image — confirm against the dataset).
    """
    rgb_image = doc["image"].convert("RGB")
    return [rgb_image]


def vqav2_process_results(doc, result):
    """Score one model prediction with the VQA accuracy metric.

    For each ground-truth answer, the prediction is compared against all the
    *other* ground-truth answers and scores ``min(1, matches / 3)``. The final
    accuracy is the mean of those per-answer scores.

    Args:
        doc: Dataset record. Reads ``doc["question_id"]`` and, when present and
            non-empty, ``doc["answers"]`` (a list of dicts with an ``"answer"``
            key). The record is NOT modified.
        result: Model outputs for this doc; must contain exactly one string.

    Returns:
        dict with ``"exact_match"`` (accuracy; 0 when no ground-truth answers
        are available) and ``"submission"`` (``question_id`` plus the
        normalized prediction).

    Raises:
        AssertionError: if ``result`` does not contain exactly one element.
    """
    eval_ai_processor = EvalAIAnswerProcessor()
    assert len(result) == 1, f"The result should be a list of length 1, but got {len(result)}."
    resAns = eval_ai_processor(result[0])
    accuracy = 0

    raw_answers = doc["answers"] if "answers" in doc else None
    # Falsy covers both None (e.g. test split) and an empty list, which would
    # otherwise make statistics.mean() raise StatisticsError.
    if raw_answers:
        # Normalize into a local list instead of rewriting doc["answers"] in
        # place, so the dataset record stays untouched across calls.
        gtAnswers = [ans["answer"].replace("\n", " ").replace("\t", " ").strip() for ans in raw_answers]

        # Mirrors the official VQA eval script: punctuation / digit / article
        # normalization is applied only when annotators disagree.
        if len(set(gtAnswers)) > 1:
            gtAnswers = [eval_ai_processor.process_digit_article(eval_ai_processor.process_punctuation(ans)) for ans in gtAnswers]
            resAns = eval_ai_processor.process_punctuation(resAns)
            resAns = eval_ai_processor.process_digit_article(resAns)

        gtAcc = []
        for held_out, _ in enumerate(gtAnswers):
            # Leave-one-out by position, so identical answer entries are not
            # all dropped together (dict-equality exclusion would do that).
            matches = sum(1 for idx, ans in enumerate(gtAnswers) if idx != held_out and ans == resAns)
            gtAcc.append(min(1, matches / 3))
        accuracy = statistics.mean(gtAcc)

    return {
        "exact_match": accuracy,
        "submission": {
            "question_id": doc["question_id"],
            "answer": resAns,
        },
    }


def vqav2_process_results_test(doc, result):
    """Score a test-split doc, keeping only the submission payload.

    Delegates to ``vqav2_process_results`` and drops the accuracy entry.
    """
    scored = vqav2_process_results(doc, result)
    return {"submission": scored["submission"]}


def vqav2_process_results_val(doc, result):
    """Score a validation-split doc, keeping only the accuracy entry.

    Delegates to ``vqav2_process_results`` and drops the submission payload.
    """
    scored = vqav2_process_results(doc, result)
    return {"exact_match": scored["exact_match"]}


def vqav2_doc_to_text(doc, lmms_eval_specific_kwargs=None):
    """Build the text prompt for a doc: optional prefix + question + optional suffix.

    Args:
        doc: Dataset record; reads ``doc["question"]``.
        lmms_eval_specific_kwargs: Optional mapping that may contain
            ``"pre_prompt"`` and/or ``"post_prompt"``. A missing key
            contributes an empty string.
    """
    options = {} if lmms_eval_specific_kwargs is None else lmms_eval_specific_kwargs
    prefix = options["pre_prompt"] if "pre_prompt" in options else ""
    suffix = options["post_prompt"] if "post_prompt" in options else ""
    question = doc["question"]
    return f"{prefix}{question}{suffix}"


def vqav2_aggregate_submissions(results, args):
    """Dump the collected submission entries to a timestamped JSON file.

    The destination path comes from ``file_utils.generate_submission_file``,
    using a file name stamped with the current local time (second precision).
    The saved path is logged.

    Args:
        results: JSON-serializable collection of submission entries.
        args: Run arguments forwarded to ``generate_submission_file``.
    """
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    out_path = file_utils.generate_submission_file(f"vqav2-test-submission-{timestamp}.json", args)
    with open(out_path, "w") as handle:
        json.dump(results, handle)
    eval_logger.info(f"Submission file saved to {out_path}")
