# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re

from mathruler.grader import extract_boxed_content, grade_answer

def extract_final_answer(resp_text):
    """Return the stripped text inside the first ``<answer>...</answer>`` pair.

    Args:
        resp_text: Model response to search. The tags may enclose newlines.

    Returns:
        The enclosed text with surrounding whitespace removed, or ``None``
        when no complete ``<answer>...</answer>`` pair is present.
    """
    found = re.search(r"<answer>(.*?)</answer>", resp_text, re.DOTALL)
    if found is None:
        return None
    return found.group(1).strip()

def format_reward(predict_str: str) -> float:
    r"""Check that *predict_str* follows the think-then-boxed layout.

    The entire string must start with ``<think>``, contain a later
    ``</think>``, and after that a ``\boxed{`` followed somewhere by ``}``.
    Newlines are allowed anywhere.

    Returns:
        1.0 if the whole string matches the layout, otherwise 0.0.
    """
    layout = re.compile(r"<think>.*</think>.*\\boxed\{.*\}.*", re.DOTALL)
    if layout.fullmatch(predict_str) is None:
        return 0.0
    return 1.0


def acc_reward(predict_str: str, ground_truth: str, use_boxed: bool = True) -> float:
    """Grade a prediction against the reference answer.

    Args:
        predict_str: Model output to grade.
        ground_truth: Reference answer passed to ``grade_answer``.
        use_boxed: When True, grade only the text returned by
            ``extract_boxed_content(predict_str)``; otherwise grade
            ``predict_str`` unchanged.

    Returns:
        1.0 if ``grade_answer`` accepts the candidate, else 0.0.
    """
    candidate = extract_boxed_content(predict_str) if use_boxed else predict_str
    return float(bool(grade_answer(candidate, ground_truth)))


def compute_score(
    predict_str: str,
    ground_truth: str,
    use_boxed: bool = True,
    format_score: float = 0.1,
    llm_reasoner_resp=None,
    max_reward: int = 1,
):
    """Combine accuracy and format rewards for a single model response.

    The function has two modes, selected by ``llm_reasoner_resp``:

    * Falsy ``llm_reasoner_resp`` (the default): returns a float,
      ``(1 - format_score) * accuracy + format_score * format``.
    * Truthy ``llm_reasoner_resp``: a "vision" reward is also computed. It is
      1.0 when ``predict_str`` is graded correct, or when ``predict_str`` is
      wrong but ``llm_reasoner_resp`` is graded correct. In the latter case
      the format reward is forced to 0.0. Returns a dict with the combined
      score and its components.

    Args:
        predict_str: Policy model response to grade.
        ground_truth: Reference answer.
        use_boxed: Forwarded to ``acc_reward``. When True, only the boxed
            content is graded.
        format_score: Weight of the format reward. The accuracy/vision terms
            are weighted by ``1 - format_score``.
        llm_reasoner_resp: Optional auxiliary reasoner response, graded
            against the same ``ground_truth``.
        max_reward: Reasoner mode only. With ``1``, ``score`` uses the vision
            reward alone, so it equals ``score_vision``. With ``2``,
            ``score`` uses ``accuracy + vision``. Ignored in the default mode.

    Returns:
        A float in the default mode. In reasoner mode, a dict with keys
        ``"score"``, ``"score_vision"``, ``"accuracy_reward"``,
        ``"format_reward"`` and ``"vision_reward"``.

    Raises:
        ValueError: If ``max_reward`` is neither 1 nor 2.
    """
    # Previously a hard-coded toggle; any other value left `score` unbound.
    if max_reward not in (1, 2):
        raise ValueError(f"max_reward must be 1 or 2, got {max_reward!r}")

    fmt = format_reward(predict_str)
    acc = acc_reward(predict_str, ground_truth, use_boxed)
    answer_weight = 1.0 - format_score

    if not llm_reasoner_resp:
        return answer_weight * acc + format_score * fmt

    # acc_reward only ever yields 0.0 or 1.0. The reasoner response is graded
    # only when the policy answer is wrong; otherwise its result is unused.
    if acc == 1.0:
        vision = 1.0
    elif acc_reward(llm_reasoner_resp, ground_truth, use_boxed) == 1.0:
        vision = 1.0
        # Only appropriate when the vision reward is applied to unmasked
        # tokens only.
        fmt = 0.0
    else:
        vision = 0.0

    score_vision = answer_weight * vision + format_score * fmt
    if max_reward == 2:
        score = answer_weight * (acc + vision) + format_score * fmt
    else:
        score = score_vision

    return {
        "score": score,
        "score_vision": score_vision,
        "accuracy_reward": acc,
        "format_reward": fmt,
        "vision_reward": vision,
    }
