import re
import sys
import string
from typing import Union, List
from collections import Counter
from transformers import AutoTokenizer
from verl.utils.reward_score.python_executor import execute_chemistry_code
# from reward_score.python_executor import execute_chemistry_code


def validate_format(text: str) -> tuple[bool, str]:
    """Check that a response follows the expected tag layout.

    The checks run in this order and the first failure is reported:

    1. ``<think>`` and ``</think>`` occur the same, non-zero number of times.
    2. Every ``<search>`` block is closed and then followed by a complete
       ``<result>...</result>`` block.
    3. The same rule for every ``<python>`` block.
    4. The text after the last ``</think>`` contains ``\\boxed{`` and ``}``.

    Returns:
        ``(True, "format is correct")`` on success, otherwise ``(False, reason)``.
    """
    think_open = text.count('<think>')
    think_close = text.count('</think>')
    if think_open != think_close:
        return False, "<think> </think> not paired"
    if not think_open or not think_close:
        return False, "<think> or </think> not found"

    # (opening tag, closing tag, message when a tag is missing, message when
    # the tags are out of order); search blocks are validated before python.
    tool_checks = (
        (
            '<search>',
            '</search>',
            "search/result tags are incomplete",
            "search/result tags are nested in the wrong order",
        ),
        (
            '<python>',
            '</python>',
            "python/result tags are incomplete",
            "python/result tags are nested in the wrong order",
        ),
    )
    for open_tag, close_tag, incomplete_msg, order_msg in tool_checks:
        cursor = 0
        while (tool_start := text.find(open_tag, cursor)) != -1:
            result_start = text.find('<result>', tool_start)
            tool_end = text.find(close_tag, tool_start)
            result_end = text.find('</result>', result_start)

            if result_start == -1 or tool_end == -1 or result_end == -1:
                return False, incomplete_msg
            if not tool_start < tool_end < result_start < result_end:
                return False, order_msg

            # Resume scanning from the end of the result just validated.
            cursor = result_end

    # The final answer is whatever follows the last closing think tag.
    closing_tag = "</think>"
    tail_start = text.rfind(closing_tag)
    if tail_start == -1:
        return False, "final answer is missing </think> tag"
    tail = text[tail_start + len(closing_tag):]

    if not ('\\boxed{' in tail and '}' in tail):
        return False, "answer is missing \\boxed{} format"

    return True, "format is correct"



def validate_format_python(text: str) -> tuple[bool, str]:
    """Check the tag layout of a response that must use ``<answer>`` tags.

    The checks run in this order and the first failure is reported:

    1. ``<think>`` and ``</think>`` occur the same, non-zero number of times.
    2. ``<answer>`` and ``</answer>`` each occur exactly once.
    3. Every ``<python>`` block is closed and then followed by a complete
       ``<result>...</result>`` block.
    4. ``<answer>`` precedes ``</answer>`` and the text between them contains
       ``\\boxed{`` and ``}``.

    Returns:
        ``(True, "format is correct")`` on success, otherwise ``(False, reason)``.
    """
    n_think_open, n_think_close = text.count('<think>'), text.count('</think>')
    if n_think_open != n_think_close:
        return False, "<think> </think> not paired"
    if n_think_open == 0 or n_think_close == 0:
        return False, "<think> or </think> not found"

    if not (text.count('<answer>') == 1 and text.count('</answer>') == 1):
        return False, "<answer> or </answer> not found"

    # Walk over each <python> block and make sure its <result> follows it.
    scan_from = 0
    code_open = text.find('<python>', scan_from)
    while code_open != -1:
        result_open = text.find('<result>', code_open)
        code_close = text.find('</python>', code_open)
        result_close = text.find('</result>', result_open)

        if result_open == -1 or code_close == -1 or result_close == -1:
            return False, "python/result tags are incomplete"
        if not code_open < code_close < result_open < result_close:
            return False, "python/result tags are nested in the wrong order"

        scan_from = result_close
        code_open = text.find('<python>', scan_from)

    # The boxed answer must sit between the single <answer> ... </answer> pair.
    answer_open = text.find('<answer>')
    answer_close = text.find('</answer>')
    if answer_close < answer_open:
        return False, "<answer> must be before </answer>"

    enclosed = text[answer_open:answer_close]
    has_boxed = '\\boxed{' in enclosed and '}' in enclosed
    if not has_boxed:
        return False, "answer is missing \\boxed{} format"

    return True, "format is correct"

# def extract_answer(text: str):
#     text = text.strip()

#     pattern = r"<answer>(.*?)</answer>"
#     match = re.search(pattern, text, re.DOTALL)
#     if not match:
#         return None
    
#     return match.group(1)

def extract_answer(text: str):
    """Return the text that follows the last ``</think>`` tag.

    Leading and trailing whitespace is stripped from ``text`` before the
    search. Returns ``None`` when no ``</think>`` tag is present.
    """
    _, closing_tag, tail = text.strip().rpartition("</think>")
    # rpartition yields an empty separator when the tag never occurs.
    return tail if closing_tag else None


def remove_boxed(s):
    """Strip the ``\\boxed`` wrapper from ``s`` and return the inner text.

    Two forms are accepted: ``\\boxed <content>`` (space-separated) and
    ``\\boxed{<content>}``. The wrapper must start the string, and the braced
    form must end with ``}``.

    Raises:
        AssertionError: if ``s`` does not have one of the two shapes.
    """
    spaced_prefix = "\\boxed "
    if spaced_prefix in s:
        assert s.startswith(spaced_prefix)
        return s[len(spaced_prefix):]

    braced_prefix = "\\boxed{"
    assert s.startswith(braced_prefix)
    assert s.endswith("}")

    return s[len(braced_prefix):-1]

def last_boxed_only_string(string):
    """Return the last ``\\boxed...`` (or ``\\fbox...``) expression in ``string``.

    * If the space-separated form ``\\boxed `` occurs anywhere, the result is
      ``\\boxed `` plus the text after its last occurrence, cut at the first
      ``$``.
    * Otherwise the search starts at the last ``\\boxed`` (falling back to the
      last ``\\fbox``) and extends to the brace that balances the first ``{``
      found from there.

    Returns ``None`` when no such expression exists or the braces never
    balance.
    """
    spaced = "\\boxed "
    if spaced in string:
        after_last = string.split(spaced)[-1]
        return spaced + after_last.split("$")[0]

    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
    if start < 0:
        return None

    # Track brace nesting from the command onward; the expression ends at the
    # closing brace that brings the depth back to zero.
    depth = 0
    for pos in range(start, len(string)):
        ch = string[pos]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return string[start:pos + 1]

    return None

def normalize_answer(s):
    """Normalize ``s`` for lenient string comparison.

    Steps, in order: lowercase, delete every character in
    ``string.punctuation``, replace the standalone words ``a``/``an``/``the``
    with a space, then collapse all whitespace runs to single spaces and trim
    the ends.
    """
    lowered = s.lower()
    without_punct = lowered.translate(str.maketrans("", "", string.punctuation))
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punct)
    return " ".join(without_articles.split())

# def get_f1_score(prediction: str, ground_truths: Union[str, List[str]]):
#     if isinstance(ground_truths, str):
#         ground_truths = [ground_truths]
    
#     final_metric = {"f1": 0, "precision": 0, "recall": 0}

#     for ground_truth in ground_truths:
#         normalized_prediction = normalize_answer(prediction)
#         normalized_ground_truth = normalize_answer(ground_truth)

#         if normalized_prediction in ["yes", "no", "noanswer"] and normalized_prediction != normalized_ground_truth:
#             continue
        
#         if normalized_ground_truth in ["yes", "no", "noanswer"] and normalized_prediction != normalized_ground_truth:
#             continue

#         prediction_tokens = normalized_prediction.split()
#         ground_truth_tokens = normalized_ground_truth.split()
#         common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
#         num_same = sum(common.values())
#         if num_same == 0:
#             continue
        
#         precision = 1.0 * num_same / len(prediction_tokens)
#         recall = 1.0 * num_same / len(ground_truth_tokens)
#         f1 = (2 * precision * recall) / (precision + recall)
        
#         final_metric["precision"] = max(precision, final_metric["precision"])
#         final_metric["recall"] = max(recall, final_metric["recall"])
#         final_metric["f1"] = max(f1, final_metric["f1"])
    
#     return final_metric['f1']


def get_score(query: str, prediction: str, ground_truths: Union[str, List[str]], method: str):
    """Run ``execute_chemistry_code`` and turn its verdict into a 0/1 score.

    All four arguments are forwarded unchanged. The call is expected to return
    a ``(result, report)`` pair.

    Returns:
        ``(1, report)`` when ``result`` is truthy, otherwise ``(0, report)``.
    """
    passed, report = execute_chemistry_code(query, prediction, ground_truths, method)
    return (1 if passed else 0), report



def compute_score(tokenizer, solution_str, ground_truth, method) -> tuple[float, str]:
    """Score one decoded rollout (prompt + response) against the ground truth.

    Args:
        tokenizer: Object whose ``eos_token`` attribute (if set and non-empty)
            is stripped from the end of the response.
        solution_str: Full decoded text containing the prompt and the response.
        ground_truth: Reference answer(s), forwarded to ``get_score``.
        method: Scoring method name, forwarded to ``get_score``.

    Returns:
        ``(score, reason)`` where ``score`` is

        * ``0`` when the assistant marker is missing, the format is invalid,
          no boxed answer can be extracted, or scoring raises;
        * ``0.1`` when the format is valid but ``get_score`` reports 0;
        * the value returned by ``get_score`` (1) when the answer is accepted.
    """
    # handling both the base model and the instruction-tuned model
    if "<｜Assistant｜><think>\n" in solution_str:
        assistant_marker = "<｜Assistant｜><think>\n"
    elif "<|im_start|>assistant\n" in solution_str:
        assistant_marker = "<|im_start|>assistant\n"
    else:
        assistant_marker = "Assistant:"
    solution_str_split = solution_str.split(assistant_marker)

    # Without any assistant marker there is no response to score; report it
    # as a format failure instead of raising IndexError.
    if len(solution_str_split) < 2:
        print("--------------------------------python bad format--------------------------------")
        return 0, 'python bad format: assistant marker not found'

    # The user turn is only delimited explicitly in the ChatML template. For
    # the other templates fall back to the prompt text preceding the
    # assistant marker rather than raising IndexError.
    # NOTE(review): confirm the scorer accepts the full prompt prefix as query.
    user_marker = "<|im_start|>user\n"
    if user_marker in solution_str:
        query = solution_str.split(user_marker)[1].split("<|im_end|>\n")[0]
    else:
        query = solution_str_split[0]

    response = solution_str_split[1]
    valid_template, reason = validate_format(response)
    if not valid_template:
        print("--------------------------------python bad format--------------------------------")
        return 0, f'python bad format: {reason}'

    # eos_token may be None (TypeError in endswith) or empty (response[:-0]
    # would erase the whole response), so only strip a non-empty token.
    eos_token = getattr(tokenizer, "eos_token", None)
    if eos_token and response.endswith(eos_token):
        response = response[:-len(eos_token)]

    answer_part = extract_answer(response)
    if answer_part is None:
        print("--------------------------------cannot extract answer--------------------------------")
        return 0, 'cannot extract answer'

    answer = last_boxed_only_string(answer_part)
    if answer is None:
        # Handle the missing box explicitly instead of relying on an
        # AttributeError from None.replace().
        print("--------------------------------find box error--------------------------------")
        return 0, 'find box error: no \\boxed{} expression found in the answer'

    try:
        answer = answer.replace("<SMILES>", "").replace("</SMILES>", "")
        f1_score, reason = get_score(query, answer, ground_truth, method)
    except Exception as e:
        # Best-effort boundary: a scorer failure yields a zero reward. The
        # "find box error" label is kept so existing log consumers still match.
        print("--------------------------------find box error--------------------------------")
        return 0, f'find box error: {e}'

    print(f"acc_score: {f1_score},answer: {answer},ground_truth: {ground_truth},reason: {reason}, method: {method}")
    if f1_score > 0:
        print("--------------------------------correct answer--------------------------------")
        return f1_score, f'correct answer, get acc score: {f1_score},reason: {reason}'

    # Well-formed but wrong answers still earn a small format reward.
    print("--------------------------------wrong answer--------------------------------")
    return 0.1, f'wrong answer but good format: {answer},reason: {reason}'

