import re
import sys
import string
from typing import Union, List
from collections import Counter
from transformers import AutoTokenizer
from verl.utils.reward_score.python_executor import execute_chemistry_code


def validate_format(text: str) -> tuple[bool, str]:
    """Check that a model response follows the expected output layout.

    The response must contain at least one ``<think>`` tag with a matching
    number of ``</think>`` tags, must not contain search/result/python tool
    tags, and the text following the last ``</think>`` must contain both
    ``\\boxed{`` and a ``}``.

    Args:
        text: The raw response text to validate.

    Returns:
        A ``(is_valid, reason)`` pair; ``reason`` names the first rule that
        failed, or states that the format is correct.
    """
    opening_tags = text.count('<think>')
    closing_tags = text.count('</think>')

    if opening_tags != closing_tags:
        return False, "<think> </think> not paired"
    if not opening_tags or not closing_tags:
        return False, "<think> or </think> not found"

    # Tool-call style tags are rejected outright for this reward.
    retrieval_tags = ('<search>', '</search>', '<result>', '</result>')
    if any(tag in text for tag in retrieval_tags):
        return False, "search/result tags are not allowed"
    if any(tag in text for tag in ('<python>', '</python>')):
        return False, "python tags are not allowed"

    # Everything after the final </think> is treated as the answer section.
    _, closing_tag, answer_content = text.rpartition("</think>")
    if not closing_tag:
        return False, "final answer is missing </think> tag"

    has_boxed = '\\boxed{' in answer_content and '}' in answer_content
    if not has_boxed:
        return False, "answer is missing \\boxed{} format"

    return True, "format is correct"

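# Legacy <answer>-tag based extractor, kept for reference; superseded by the
# </think>-based extract_answer defined below.
# def extract_answer(text: str):
#     text = text.strip()

#     pattern = r"<answer>(.*?)</answer>"
#     match = re.search(pattern, text, re.DOTALL)
#     if not match:
#         return None
#     return match.group(1)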

def extract_answer(text: str):
    """Return the text that follows the last ``</think>`` tag.

    Leading and trailing whitespace is stripped from ``text`` before the
    search, so the returned tail never ends in whitespace.

    Args:
        text: The model response.

    Returns:
        The substring after the final ``</think>``, or ``None`` when the tag
        does not occur.
    """
    _, closing_tag, tail = text.strip().rpartition("</think>")
    # rpartition yields an empty separator when the tag is absent.
    return tail if closing_tag else None


def last_boxed_only_string(string):
    """Extract the last ``\\boxed{...}`` (or ``\\fbox{...}``) expression.

    If the space-delimited form ``\\boxed `` occurs anywhere in the input it
    takes precedence: the result is ``\\boxed `` plus the text after its last
    occurrence, cut at the first ``$``. Otherwise the last ``\\boxed`` (or,
    failing that, the last ``\\fbox``) is located and the span up to its
    matching closing brace is returned.

    Args:
        string: Text to search. (Inside this function the name refers to the
            argument, not the ``string`` module.)

    Returns:
        The boxed expression including the command and braces, or ``None``
        when no command is found or its braces never balance.
    """
    start = string.rfind("\\boxed")

    spaced_form = "\\boxed "
    if spaced_form in string:
        after_last = string.split(spaced_form)[-1]
        return spaced_form + after_last.split("$")[0]

    if start < 0:
        start = string.rfind("\\fbox")
    if start < 0:
        return None

    # Scan forward from the command, tracking brace depth; the expression
    # ends at the closing brace that brings the depth back to zero.
    depth = 0
    for position, char in enumerate(string[start:], start):
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return string[start:position + 1]

    return None

def get_score(query: str, prediction: str, ground_truths: Union[str, List[str]], method: str):
    """Score a prediction by delegating to ``execute_chemistry_code``.

    Args:
        query: The user question the prediction answers.
        prediction: The extracted answer to evaluate.
        ground_truths: Reference answer(s), forwarded unchanged.
        method: Evaluation method identifier, forwarded unchanged.

    Returns:
        ``(1, report)`` when the executor's result is truthy, otherwise
        ``(0, report)``; ``report`` is whatever the executor returned.
    """
    passed, report = execute_chemistry_code(query, prediction, ground_truths, method)
    return (1 if passed else 0), report

def _split_prompt_and_response(solution_str):
    """Split a decoded rollout into ``(query, response)``.

    Returns ``None`` when none of the known assistant markers is present.
    """
    # (marker, text to put back in front of the response). The DeepSeek-style
    # marker swallows the opening <think> tag, which validate_format needs in
    # order to pair it with </think>, so it is restored for that template.
    templates = (
        ("<｜Assistant｜><think>\n", "<think>\n"),
        ("<|im_start|>assistant\n", ""),
        ("Assistant:", ""),
    )
    for marker, restored_prefix in templates:
        if marker in solution_str:
            pieces = solution_str.split(marker)
            break
    else:
        return None

    response = restored_prefix + pieces[1]

    user_marker = "<|im_start|>user\n"
    if user_marker in solution_str:
        query = solution_str.split(user_marker)[1].split("<|im_end|>\n")[0]
    else:
        # Non-ChatML prompt: fall back to the text preceding the response.
        # NOTE(review): confirm execute_chemistry_code accepts the raw prompt
        # as the query for these templates.
        query = pieces[0].strip()

    return query, response


def compute_score(tokenizer, solution_str, ground_truth, method) -> tuple[float, str]:
    """Compute the reward for one decoded rollout.

    Handles the ``<｜Assistant｜><think>``, ``<|im_start|>assistant`` and plain
    ``Assistant:`` prompt layouts.

    Args:
        tokenizer: Object exposing an ``eos_token`` attribute; a trailing EOS
            token is stripped from the response before answer extraction.
        solution_str: Full decoded text (prompt followed by model response).
        ground_truth: Reference answer(s), forwarded to ``get_score``.
        method: Evaluation method identifier, forwarded to ``get_score``.

    Returns:
        ``(score, reason)``. ``score`` is 0 for a malformed response, a
        missing answer or an error while scoring; 0.1 for a well-formed but
        wrong answer; otherwise the positive score from ``get_score``.
    """
    parts = _split_prompt_and_response(solution_str)
    if parts is None:
        # Previously this surfaced as an uncaught IndexError.
        print("--------------------------------python bad format--------------------------------")
        return 0, 'python bad format: assistant response marker not found'
    query, response = parts

    valid_template, reason = validate_format(response)
    if not valid_template:
        print("--------------------------------python bad format--------------------------------")
        return 0, f'python bad format: {reason}'

    # Guard against a missing/empty EOS token: endswith(None) raises and an
    # empty token would make response[:-0] wipe the whole response.
    eos_token = getattr(tokenizer, "eos_token", None)
    if eos_token and response.endswith(eos_token):
        response = response[:-len(eos_token)]

    answer_part = extract_answer(response)
    if answer_part is None:
        print("--------------------------------cannot extract answer--------------------------------")
        return 0, 'cannot extract answer'

    # Best-effort boundary: any failure while extracting or scoring yields a
    # zero reward instead of propagating.
    try:
        answer = last_boxed_only_string(answer_part)
        if answer is None:
            print("--------------------------------find box error--------------------------------")
            return 0, 'find box error: no \\boxed{} expression found in answer'
        answer = answer.replace("<SMILES>", "").replace("</SMILES>", "")
        f1_score, reason = get_score(query, answer, ground_truth, method)
        print(f"acc_score: {f1_score},answer: {answer},ground_truth: {ground_truth},reason: {reason}, method: {method}")
        if f1_score > 0:
            print("--------------------------------correct answer--------------------------------")
            return f1_score, f'correct answer, get acc score: {f1_score},reason: {reason}'
        # Well-formed but incorrect answers still earn a small format reward.
        print("--------------------------------wrong answer--------------------------------")
        return 0.1, f'wrong answer but good format: {answer},reason: {reason}'
    except Exception as e:
        print("--------------------------------find box error--------------------------------")
        return 0, f'find box error: {e}'
