import re
import os
import json
import ipdb
import math
import numpy as np


def levenshtein_distance(s1, s2):
    if len(s1) > len(s2):
        s1, s2 = s2, s1

    distances = range(len(s1) + 1)
    for i2, c2 in enumerate(s2):
        distances_ = [i2 + 1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                distances_.append(distances[i1])
            else:
                distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
        distances = distances_
    return distances[-1]


def vqa_evaluation(predict, answers):
    score = 0
    if type(answers) == list:
        for j in range(len(answers)):
            if isinstance(answers[j], (int, float)):
                answers[j] = str(answers[j])
            try:
                answer = answers[j].lower().strip().replace("\n", " ")
            except:
                ipdb.set_trace()
            if isinstance(predict, (int, float)):
                predict = str(predict)
            predict = predict.lower().strip().replace("\n", " ")
            if len(answer.split()) < 5:
                if answer in predict:
                    score = 1
            else:
                dist = levenshtein_distance(predict, answer)
                length = max(len(predict), len(answer))
                ANLS_value = 0.0 if length == 0 else float(dist) / float(length)
                ANLS_value = 1 - ANLS_value

                if ANLS_value >= 0.5 and ANLS_value > score:
                    score = ANLS_value

    else:
        answers = answers.lower().strip().replace("\n", " ")
        predict = predict.lower().strip().replace("\n", " ")
        if len(answers.split()) < 5:
            if answers in predict:
                score = 1
        else:
            dist = levenshtein_distance(predict, answers)
            length = max(len(predict), len(answers))
            ANLS_value = 0.0 if length == 0 else float(dist) / float(length)
            ANLS_value = 1 - ANLS_value

            if ANLS_value >= 0.5 and ANLS_value > score:
                score = ANLS_value

    return score


def cn_vqa_evaluation(predict, answers):
    score = 0
    if type(answers) == list:
        for j in range(len(answers)):
            if isinstance(answers[j], (int, float)):
                answers[j] = str(answers[j])
            try:
                answer = answers[j].lower().strip().replace("\n", " ").replace(" ", "")
            except:
                ipdb.set_trace()
            if isinstance(predict, (int, float)):
                predict = str(predict)
            predict = predict.lower().strip().replace("\n", " ").replace(" ", "")
            if len(answer.split(",")) < 4:
                if answer in predict:
                    score = 1
            else:
                dist = levenshtein_distance(predict, answer)
                length = max(len(predict), len(answer))
                ANLS_value = 0.0 if length == 0 else float(dist) / float(length)
                ANLS_value = 1 - ANLS_value

                if ANLS_value >= 0.5 and ANLS_value > score:
                    score = ANLS_value

    else:
        answers = answers.lower().strip().replace("\n", " ").replace(" ", "")
        predict = predict.lower().strip().replace("\n", " ").replace(" ", "")
        if len(answers.split(",")) < 4:
            if answers in predict:
                score = 1
            else:
                dist = levenshtein_distance(predict, answers)
                length = max(len(predict), len(answers))
                ANLS_value = 0.0 if length == 0 else float(dist) / float(length)
                ANLS_value = 1 - ANLS_value

                if ANLS_value >= 0.5 and ANLS_value > score:
                    score = ANLS_value

    return score


def vqa_evaluation_case_sensitive(predict, answers):
    score = 0
    if type(answers) == list:
        for j in range(len(answers)):
            if isinstance(answers[j], (int, float)):
                answers[j] = str(answers[j])
            try:
                answer = answers[j].strip().replace("\n", " ")
            except:
                ipdb.set_trace()
            predict = predict.strip().replace("\n", " ")
            if len(answer.split()) < 5:
                if answer in predict:
                    score = 1
            else:
                dist = levenshtein_distance(predict, answer)
                length = max(len(predict), len(answer))
                ANLS_value = 0.0 if length == 0 else float(dist) / float(length)
                ANLS_value = 1 - ANLS_value

                if ANLS_value >= 0.5 and ANLS_value > score:
                    score = ANLS_value

    else:
        answers = answers.strip().replace("\n", " ")
        predict = predict.strip().replace("\n", " ")
        if len(answers.split()) < 5:
            if answers in predict:
                score = 1
            else:
                dist = levenshtein_distance(predict, answers)
                length = max(len(predict), len(answers))
                ANLS_value = 0.0 if length == 0 else float(dist) / float(length)
                ANLS_value = 1 - ANLS_value

                if ANLS_value >= 0.5 and ANLS_value > score:
                    score = ANLS_value

    return score


def extract_first_number(string):
    match = re.search(r'\d+', string)
    if match:
        return int(match.group())
    return None


def counting_evaluation(predict, answers, eval_method):
    score = 0

    if isinstance(predict, str):
        predict_processed = predict.lower().strip().replace("\n", " ")
    elif math.isnan(predict):
        return 0
    else:
        predict_processed = int(predict)
    if type(answers) == list:
        temp_score = 0
        for j in range(len(answers)):
            if isinstance(answers[j], (int, float)):
                answers[j] = str(answers[j])
            answer = answers[j].lower().strip().replace("\n", " ")
            if eval_method == "exact match":
                if answer in predict:
                    score = 1
                else:
                    score = 0
            elif eval_method == "regression":
                predict_number = extract_first_number(predict_processed)
                if predict_number:

                    answer = int(answer)

                    if predict_number <= 0 or predict_number >= 2 * answer:
                        score = 0
                    else:
                        iou = 1 - abs(predict_number - answer) / answer
                        if iou > 0.5:
                            score = iou
                        else:
                            score = 0
                else:
                    score = 0
            if score > temp_score:
                temp_score = score
        score = temp_score

    else:
        answer = answers.lower().strip().replace("\n", " ")
        predict = predict.lower().strip().replace("\n", " ")
        if eval_method == "exact match":
            if answer in predict:
                score = 1
            else:
                score = 0
        elif eval_method == "regression":
            predict = extract_first_number(predict)
            if predict:
                answer = int(answer)
                if predict <= 0 or predict >= 2 * answer:
                    score = 0
                else:
                    iou = 1 - abs(predict - answer) / answer

                    if iou > 0.5:
                        score = iou
                    else:
                        score = 0
            else:
                score = 0
    return score


def math_expression_evaluation(predict, answers):
    score = 0
    if type(answers) == list:
        for j in range(len(answers)):
            answer = answers[j].strip().replace("\n", " ").replace(" ", "")
            predict = predict.strip().replace("\n", " ").replace(" ", "")
            if answer in predict:
                score = 1
    else:
        answers = answers.strip().replace("\n", " ").replace(" ", "")
        predict = predict.strip().replace("\n", " ").replace(" ", "")
        if answers in predict:
            score = 1
    return score


def remove_text_tags(latex_str):
    """
    Removes LaTeX \text{...} tags while keeping their content.

    :param latex_str: A string containing LaTeX expressions
    :return: The processed string with \text{...} tags removed
    """

    pattern = r'\\text\{([^{}]*)\}'

    processed_str = re.sub(pattern, r'\1', latex_str)

    return processed_str


def cn_math_expression_evaluation(predict, answers):
    score = 0

    assert len(answers) == 1
    answers = [remove_text_tags(answers[0])]
    predict = remove_text_tags(predict)

    if type(answers) == list:
        for j in range(len(answers)):
            answer = answers[j].strip().replace("\n", " ").replace(" ", "")
            predict = predict.strip().replace("\n", " ").replace(" ", "")
            if answer in predict:
                score = 1
    else:
        answers = answers.strip().replace("\n", " ").replace(" ", "")
        predict = predict.strip().replace("\n", " ").replace(" ", "")
        if answers in predict:
            score = 1
    return score


if __name__ == "__main__":
    test_predict = "apple pie and banana"
    test_answers = ["apple", "banana pie", "apple pie and orange"]

    vqa_score = vqa_evaluation(test_predict, test_answers)
    print(f"VQA evaluation score for predict '{test_predict}' and answers {test_answers}: {vqa_score}")