# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     XXXX
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re


# def extract_solution(solution_str, method='strict'):
#     assert method in ['strict', 'flexible']

#     if method == 'strict':
#         # this also tests the formatting of the model
#         solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
#         if solution is None:
#             final_answer = None
#         else:
#             final_answer = solution.group(0)
#             final_answer = final_answer.split('#### ')[1].replace(',', '').replace('$', '')
#     elif method == 'flexible':
#         answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str)
#         final_answer = None
#         if len(answer) == 0:
#             # no reward if there is no answer
#             pass
#         else:
#             invalid_str = ['', '.']
#             # find the last number that is not '.'
#             for final_answer in reversed(answer):
#                 if final_answer not in invalid_str:
#                     break
#     return final_answer

import json
import re
import pandas as pd
import requests
import json
import argparse
import re
import numpy as np
import string
import subprocess
import time


def maybe_normalize_float(span: str):
    """Return ``str(float(span))`` when *span* looks like a plain decimal number.

    A span qualifies when it is non-empty, is not a lone ".", and is either
    an explicitly signed number with at least one digit before the optional
    decimal point, or an unsigned run of digits with an optional decimal
    point. Any other span is returned unchanged.
    """
    # Empty spans and a bare "." are never treated as numbers.
    if not span or span == ".":
        return span
    signed = re.match(r"^[+-][0-9]+[.]?[0-9]*$", span)
    if signed or re.match(r"^[0-9]*[.]?[0-9]*$", span):
        return str(float(span))
    return span


def maybe_normalize_number(text: str) -> str:
    """Convert an English number word ("zero" .. "nineteen") to a float string.

    The comparison is exact (case-sensitive, no trimming). A recognized word
    is returned as ``str(float(value))``, e.g. "three" -> "3.0"; any other
    text is returned unchanged.
    """
    # Position in the tuple is the numeric value of the word.
    number_words = (
        "zero", "one", "two", "three", "four",
        "five", "six", "seven", "eight", "nine",
        "ten", "eleven", "twelve", "thirteen", "fourteen",
        "fifteen", "sixteen", "seventeen", "eighteen", "nineteen",
    )
    try:
        value = number_words.index(text)
    except ValueError:
        return text
    return str(float(value))


def remove_punc(text: str) -> str:
    """Return *text* with every character from ``string.punctuation`` deleted."""
    deletion_table = str.maketrans("", "", string.punctuation)
    return text.translate(deletion_table)


def remove_articles(text: str) -> str:
    """Replace each standalone lowercase "a", "an" or "the" with a single space.

    Matching is case-sensitive and word-bounded, so "The" and "another" are
    left untouched. Surrounding whitespace is not collapsed.
    """
    article_pattern = re.compile(r"\b(a|an|the)\b")
    return article_pattern.sub(" ", text)


def check_overlap(str1, str2):
    """Heuristically decide whether two strings overlap.

    Spaces and punctuation are removed from both strings first. The result
    is True when either cleaned string is a substring of the other; this
    includes the case where either one is empty, because the empty string is
    a substring of every string. Otherwise the characters of ``str1`` other
    than "0" that also occur somewhere in ``str2`` are counted, and the
    result is True when that count exceeds half the length of either string.

    Args:
        str1: first string.
        str2: second string.

    Returns:
        bool: True if the strings are considered overlapping.
    """
    str1 = remove_punc(str1.replace(" ", ""))
    str2 = remove_punc(str2.replace(" ", ""))
    if str1 in str2 or str2 in str1:
        return True
    # Both strings are non-empty past this point (an empty string would have
    # satisfied the substring test above), so the divisions are safe.
    count = sum(1 for letter in str1 if letter != "0" and letter in str2)
    return count / len(str1) > 0.5 or count / len(str2) > 0.5


def get_answer(pred):
    """Pull the answer out of a trailing "The answer is X." sentence.

    The sentence must end the string, start with "The" or "the", and X must
    not contain a period.

    Returns:
        ``(answer, True)`` with surrounding double quotes stripped from the
        answer when the sentence is found, otherwise ``(pred, False)``.
    """
    found = re.search(r"(The|the) answer is ([^\.]+)\.$", pred)
    if found is None:
        return pred, False
    return found.group(2).strip('"'), True


def eval_ex_match(pred, gold_result):
    """Compare a predicted answer with a gold answer, numerically when possible.

    Both inputs are lower-cased first (``gold_result`` goes through ``str``).

    Numeric path: taken when both values consist only of digits once ".",
    ",", "%" and one leading "-" are ignored, each has at most one ".", and
    both parse as floats. The absolute values are compared with a tolerance
    of 0.01, so the sign is ignored.

    Text path: taken otherwise. ``pred`` is split on ", " (with " and "
    treated as a separator when the gold answer uses "|"); the gold answer is
    split on "|" if present, else on ", ". Every span is stripped, has
    articles and punctuation removed, and has number words and plain decimals
    normalized. The two sorted span lists must then be equal.

    Args:
        pred: predicted answer string.
        gold_result: gold answer; any value accepted by ``str``.

    Returns:
        A tuple ``(is_match, pred_repr, gold_repr)``. On the numeric path the
        representations are ``str`` of the absolute float values; on the text
        path they are the sorted lists of normalized spans.
    """
    pred = pred.lower()
    gold_result = str(gold_result).lower()

    def _digit_skeleton(text):
        # Drop number formatting characters and one leading minus sign so
        # that str.isdigit() can decide whether the text looks numeric.
        skeleton = text.replace(".", "").replace(",", "").replace("%", "")
        return skeleton[1:] if skeleton.startswith("-") else skeleton

    looks_numeric = (
        _digit_skeleton(pred).isdigit()
        and _digit_skeleton(gold_result).isdigit()
        and pred.count(".") < 2
        and gold_result.count(".") < 2
    )
    if looks_numeric:
        # A trailing "." is sentence punctuation, not a decimal point.
        pred_text = pred[:-1] if pred.endswith(".") else pred
        pred_text = pred_text.replace(",", "").replace("%", "")
        gold_text = gold_result.replace(",", "").replace("%", "")
        try:
            pred_value = abs(float(pred_text))
            gold_value = abs(float(gold_text))
        except ValueError:
            # isdigit() accepts some inputs float() rejects (e.g. superscript
            # digits, or a "-" that is not at the very start once
            # punctuation is kept). Compare those as text instead of
            # crashing.
            pass
        else:
            is_close = abs(pred_value - gold_value) < 0.01
            return is_close, str(pred_value), str(gold_value)

    if " and " in pred and "|" in gold_result:
        pred = pred.replace(" and ", ", ")
    gold_separator = "|" if "|" in gold_result else ", "

    def _normalize(span):
        cleaned = remove_punc(remove_articles(span.strip()))
        return maybe_normalize_float(maybe_normalize_number(cleaned))

    pred_spans = sorted(_normalize(span) for span in pred.split(", "))
    gold_spans = sorted(
        _normalize(span) for span in gold_result.split(gold_separator)
    )
    return pred_spans == gold_spans, pred_spans, gold_spans


def match_all(data, option):
    """Check that every label matches the corresponding prediction.

    Args:
        data: mapping-like object where ``data["label"]`` and
            ``data[option + " prediction"]`` are integer-indexable sequences.
        option: prefix of the prediction column name.

    Returns:
        bool: True when both sequences have the same length and
        ``eval_ex_match`` reports a match for every pair; False otherwise.
    """
    labels = data["label"]
    predictions = data[option + " prediction"]
    if len(labels) != len(predictions):
        return False
    # NOTE(review): the label is passed as eval_ex_match's `pred` argument and
    # the prediction as `gold_result`; confirm this ordering is intended.
    return all(
        eval_ex_match(str(labels[i]), predictions[i])[0]
        for i in range(len(labels))
    )


def find_matching_brace(s, start):
    """
    Scan s from index start, tracking brace nesting depth, and return the index
    of the '}' that brings the depth back to zero. start is expected to point
    at the opening '{'. Return -1 if the end of s is reached first.
    """
    depth = 0
    for pos in range(start, len(s)):
        ch = s[pos]
        if ch == '{':
            depth += 1
            continue
        if ch != '}':
            continue
        depth -= 1
        if depth == 0:
            return pos
    return -1

def extract_boxed_content(response_content):
    r"""
    Extract the content inside the last occurrence of \boxed{...} in
    response_content, handling nested braces.

    Returns the inner text with surrounding whitespace stripped, or None
    (after printing a diagnostic) when there is no \boxed{ or its opening
    brace is never closed.
    """
    keyword = r"\boxed{"
    idx = response_content.rfind(keyword)
    if idx == -1:
        print("No '\\boxed{}' found.")
        return None
    # The keyword ends with '{', so its last character is the opening brace.
    open_brace = idx + len(keyword) - 1
    close_brace = find_matching_brace(response_content, open_brace)
    if close_brace == -1:
        print("No matching '}' found.")
        return None
    # Return the inner content without the outer braces
    return response_content[open_brace + 1:close_brace].strip()

def remove_latex_text(s):
    r"""
    Unwrap every \text{...} in s: each occurrence is replaced by its inner
    content with surrounding whitespace stripped (nested braces supported).
    Processing stops at the first \text{ whose brace is never closed.
    """
    keyword = r"\text{"
    idx = s.find(keyword)
    while idx != -1:
        # The keyword ends with '{', so this index always holds the open brace.
        open_brace = idx + len(keyword) - 1
        close_brace = find_matching_brace(s, open_brace)
        if close_brace == -1:
            break
        unwrapped = s[open_brace + 1:close_brace].strip()
        s = s[:idx] + unwrapped + s[close_brace + 1:]
        idx = s.find(keyword)
    return s

def extract_boxed_answer(response_content):
    r"""
    Return the cleaned answer from the last \boxed{...} in response_content:
    \text{...} wrappers are unwrapped, "\%" becomes "%", and surrounding
    whitespace is stripped. Return None when no boxed content can be extracted.
    """
    boxed = extract_boxed_content(response_content)
    if boxed is not None:
        cleaned = remove_latex_text(boxed).replace(r"\%", "%")
        return cleaned.strip()
    print("No boxed content extracted.")
    return None


def compute_score(solution_str, ground_truth, method='strict', format_score=0., score=1.):
    r"""The scoring function for GSM8k, based on the last \boxed{...} answer.

    Reference: Trung, Luong, et al. "Reft: Reasoning with reinforced fine-tuning." Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024.

    The extracted answer and the ground truth are printed on every call, as
    is the score awarded whenever an answer was extracted.

    Args:
        solution_str: the solution text
        ground_truth: the ground truth
        method: kept in the signature ('strict' / 'flexible') but not used by
            this implementation
        format_score: the score returned when an answer is extracted but does
            not match the ground truth
        score: the score for the correct answer

    Returns:
        0 when no boxed answer can be extracted, ``score`` when the extracted
        answer matches the ground truth, ``format_score`` otherwise.
    """
    answer = extract_boxed_answer(solution_str)
    print("gsm8k extracted content:", answer)
    print("gsm8k ground truth:", ground_truth)
    if answer is None:
        return 0

    matched, _, _ = eval_ex_match(answer, ground_truth)
    if not matched:
        print("format score:", format_score)
        return format_score
    print("score:", score)
    return score