# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import string
import random

def normalize_answer(s):
    """Canonicalize an answer string for lenient comparison.

    Applied in order: lowercase; delete every character in
    ``string.punctuation``; replace the standalone words "a", "an" and "the"
    with a space; collapse whitespace runs to single spaces and trim the ends.
    """
    text = s.lower()
    # One C-level pass that deletes all punctuation characters.
    text = text.translate(str.maketrans("", "", string.punctuation))
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def em_check(prediction, golden_answers):
    """Return 1 if *prediction* exactly matches any golden answer, else 0.

    Both sides are compared after ``normalize_answer``. *golden_answers* may be
    a single string or an iterable of strings; matching stops at the first hit.
    """
    candidates = [golden_answers] if isinstance(golden_answers, str) else golden_answers
    target = normalize_answer(prediction)
    hit = any(normalize_answer(candidate) == target for candidate in candidates)
    return 1 if hit else 0


def restore_multilogue(text):
    """Parse a ChatML-style transcript into a list of message dicts.

    Each ``<|im_start|>ROLE`` + newline + body + ``<|im_end|>`` span becomes
    ``{'role': ROLE.lower(), 'content': body.strip()}``, in order of appearance.
    Text outside such spans (including an unterminated final span) is ignored.
    """
    turn_re = re.compile(r'<\|im_start\|>(\w+)\n(.*?)<\|im_end\|>', re.DOTALL)
    return [
        {'role': found.group(1).lower(), 'content': found.group(2).strip()}
        for found in turn_re.finditer(text)
    ]


def check_strategy(conv):
    """Check that a message's content is wrapped in ``<strategy>`` tags.

    Whitespace around the ``<strategy>`` / ``</strategy>`` tags is removed
    first. The remaining text must begin with ``<strategy>`` and end with
    ``</strategy>``, with nothing before or after.

    Args:
        conv: message dict with a ``'content'`` string.

    Returns:
        True if the content has that shape, otherwise False.
    """
    split_pattern = r"(</?strategy>)"
    parts = re.split(split_pattern, conv['content'])
    content = ''.join(x.strip() for x in parts if x.strip())
    # DOTALL so a multi-line strategy still matches; without it "." stops at
    # the first newline and such messages were rejected.
    return re.fullmatch(r'<strategy>.*?</strategy>', content, re.DOTALL) is not None


def check_action_loop(convs):
    """Validate the alternating act/reflect format of assistant turns.

    Only messages with ``role == 'assistant'`` are considered; whitespace
    around the think/search/reflect/answer tags is removed before matching.

    * Even-indexed (0-based) assistant turns must be
      ``<think>...</think>`` followed by ``<search>...</search>`` or
      ``<answer>...</answer>``. The closing tag must match the opening one.
    * Odd-indexed assistant turns must be
      ``<think>...</think><reflect>...</reflect>``.

    Args:
        convs: list of message dicts with ``'role'`` and ``'content'``.

    Returns:
        True if every assistant turn conforms (vacuously True if there are
        none), otherwise False.
    """
    split_pattern = r"(</?(?:think|search|reflect|answer)>)"
    # The opening "<" is mandatory and (?P=act) forces the closing tag to
    # mirror the opening one. Previously "<?" and "</?" accepted
    # "search>...", "<search>...<search>" and "<search>...</answer>".
    act_pattern = r'<think>.*?</think><(?P<act>search|answer)>.*?</(?P=act)>'
    reflect_pattern = r'<think>.*?</think><reflect>.*?</reflect>'

    assistant_contents = [conv['content'] for conv in convs if conv['role'] == 'assistant']
    for i, content in enumerate(assistant_contents):
        parts = re.split(split_pattern, content)
        content = ''.join(x.strip() for x in parts if x.strip())
        match_pattern = act_pattern if i % 2 == 0 else reflect_pattern
        # DOTALL so multi-line reasoning or queries are accepted, consistent
        # with the other parsers in this module.
        if not re.fullmatch(match_pattern, content, re.DOTALL):
            return False

    return True


def is_valid_sequence(text):
    """Check that a rendered chat transcript follows the expected agent format.

    The transcript is parsed with ``restore_multilogue``. The message at
    index 2 must pass ``check_strategy``, and all messages from index 3
    onward must pass ``check_action_loop``. In the ``__main__`` sample, index
    0 is the system message, index 1 the user question and index 2 the
    assistant's strategy turn; presumably all inputs follow that layout.

    Args:
        text: full transcript string.

    Returns:
        True if the format checks pass. Returns False instead of raising
        IndexError when fewer than three messages can be parsed, e.g. for a
        truncated or malformed rollout.
    """
    convs = restore_multilogue(text)
    if len(convs) < 3:
        return False
    return check_strategy(convs[2]) and check_action_loop(convs[3:])


def extract_solution(solution_str):
    """Return the stripped body of the last ``<answer>...</answer>`` block.

    Returns None unless *solution_str* contains at least two answer blocks;
    zero or one block yields None. Bodies may span multiple lines.

    NOTE(review): the threshold of two presumably skips an example
    ``<answer>`` tag in the prompt (the ``__main__`` sample has an empty one
    in its system message). Confirm against the prompt template.
    """
    bodies = re.findall(r'<answer>(.*?)</answer>', solution_str, re.DOTALL)
    if len(bodies) < 2:
        return None
    return bodies[-1].strip()


def extract_learnings_blocks(text: str) -> list[str]:
    """Return the stripped bodies of all ``<learnings>...</learnings>`` blocks, in order.

    Bodies may span multiple lines; an empty list is returned when none exist.
    """
    return [
        found.group(1).strip()
        for found in re.finditer(r"<learnings>(.*?)</learnings>", text, re.DOTALL)
    ]


def is_retrieval_correct(text: str, golden_answers: list[str]) -> bool:
    """Return True if any golden answer appears inside a ``<learnings>`` block.

    Each block and each golden answer is passed through ``normalize_answer``,
    and the check is substring containment (answer in block).

    Args:
        text: full solution text to scan for ``<learnings>...</learnings>``.
        golden_answers: acceptable answers. A bare string is treated as a
            single answer, consistent with ``em_check``.

    Returns:
        True on the first containment hit, otherwise False.
    """
    seqs = extract_learnings_blocks(text)
    if not seqs:
        return False
    if isinstance(golden_answers, str):
        # Iterating a str would test single characters, which almost always hit.
        golden_answers = [golden_answers]
    # Normalize each answer once. Drop answers that normalize to "" (e.g.
    # "the"), since the empty string is a substring of every block.
    targets = [t for t in (normalize_answer(g) for g in golden_answers) if t]
    for seq in seqs:
        normalized_seq = normalize_answer(seq)
        if any(t in normalized_seq for t in targets):
            return True
    return False


def compute_score_em(solution_str, ground_truth, method='strict', structure_format_score=0.2, final_format_score=0.1, retrieval_score=0.1, score=1.):
    """Score a rollout by exact match, with format and retrieval shaping.

    Args:
        solution_str: full rendered transcript of the rollout.
        ground_truth: mapping whose ``'target'`` entry holds the golden
            answer(s), as passed to ``em_check``.
        method: accepted for interface compatibility; not used here.
        structure_format_score: credit for a transcript that passes
            ``is_valid_sequence``.
        final_format_score: credit for an extracted but wrong answer in a
            transcript that fails ``is_valid_sequence``.
        retrieval_score: extra credit when a well-formed transcript's
            ``<learnings>`` contain a golden answer.
        score: reward for a correct answer.

    Returns (values with default arguments):
        correct answer, valid format        -> score                                 (1.0)
        correct answer, invalid format      -> score - structure_format_score        (0.8)
        no/wrong answer, valid format       -> structure_format_score
                                               (+ retrieval_score if retrieved)      (0.2 / 0.3)
        no answer, invalid format           -> 0
        wrong answer, invalid format        -> final_format_score                    (0.1)

    About 1 in 64 calls prints the golden answers, the extracted answer and
    the transcript for debugging.
    """
    well_formed = is_valid_sequence(solution_str)
    # Retrieval credit is only ever granted to well-formed transcripts, so the
    # scan is skipped otherwise.
    retrieved = well_formed and is_retrieval_correct(solution_str, ground_truth['target'])
    predicted = extract_solution(solution_str=solution_str)

    # Randomly sample roughly 1/64 of calls for a debug dump.
    if random.randint(1, 64) == 1:
        print("--------------------------------")
        print(f"Golden answers: {ground_truth['target']}")
        print(f"Extracted answer: {predicted}")
        print(f"Solution string: {solution_str}")

    if predicted is not None and em_check(predicted, ground_truth['target']):
        # Correct answer: full reward, minus the structure credit if malformed.
        return score if well_formed else score - structure_format_score
    if well_formed:
        # Missing or wrong answer, but the transcript is well formed.
        return (structure_format_score + retrieval_score) if retrieved else structure_format_score
    # Malformed transcript without a correct answer.
    return 0 if predicted is None else final_format_score


if __name__ == '__main__':
    # Smoke test on a hand-written transcript. Turn layout: system, user,
    # assistant <strategy>, then assistant turns alternating search/reflect,
    # ending with an <answer>. It passes is_valid_sequence(), and its final
    # answer matches the golden answer, so this should print 1.0 (plus, about
    # 1 time in 64, the debug dump from compute_score_em).
    # The system message contains an empty <answer></answer>. extract_solution()
    # returns None unless there are at least two <answer> blocks, so the
    # model's answer here is the second one.
    solution_str = '''<|im_start|>system\n<answer></answer><|im_end|>\n<|im_start|>user\nInitial<|im_end|>\n<|im_start|>assistant\n<strategy>s</strategy><|im_end|>\n<|im_start|>user\ne<|im_end|>\n<|im_start|>assistant\n<think>THINK0</think>\n<search>\n<query>QUERY0</query><goal>GOAL0</goal>\n</search><|im_end|>\n<|im_start|>user\nl<|im_end|>\n<|im_start|>assistant\n<think>t</think>\n<reflect>True</reflect><|im_end|>\n<|im_start|>user\nGood<|im_end|>\n<|im_start|>assistant\n<think>THINK</think>\n<search>\n<query>QUERY</query><goal>GOAL</goal>\n</search><|im_end|>\n<|im_start|>user\nl<|im_end|>\n<|im_start|>assistant\n<think>1975.</think>\n<reflect>True</reflect><|im_end|>\n<|im_start|>assistant\n<think>The search results indicate that Light In The Attic Records has re-released works by Jim Sullivan, who disappeared in 1975.</think>\n<answer>Jim Sullivan</answer><|im_end|>\n
    '''
    # compute_score_em reads the golden answers from the 'target' key.
    ground_truth = {'target': ['Jim Sullivan']}
    score = compute_score_em(solution_str=solution_str,ground_truth=ground_truth)
    print(score)
