# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import openai
import re 
import requests
import json

# llm checker 
def extract_score(text):
    """Return the integer inside the first ``[[N]]`` marker in *text*, or 0 if none.

    Only non-negative integer markers match (e.g. "Rating: [[7]]" -> 7);
    markers such as "[[7.5]]" or "[[-1]]" do not match. A non-str *text*
    raises TypeError from ``re.search``, as before.
    """
    match = re.search(r'\[\[(\d+)\]\]', text)
    if match is None:
        return 0
    return int(match.group(1))


try:
    from math_verify.metric import math_metric
    from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig
except ImportError:
    print("To use Math-Verify, please install it first by running `pip install math-verify`.")


def remove_think_tags(text: str) -> str:
    """Strip every complete ``<think>...</think>`` span from *text*.

    Spans may contain newlines and are matched non-greedily, so several
    separate think blocks are each removed while the text between them is kept.
    An opening ``<think>`` without a closing tag is left untouched.
    """
    think_block = re.compile(r"<think>.*?</think>", re.DOTALL)
    return think_block.sub("", text)


def extract_content_from_backticks(text, pattern=r"```(.*?)```"):
    """Return the whitespace-stripped captures of *pattern* found in *text*.

    *pattern* is applied with ``re.DOTALL`` so captures may span lines. The
    default captures the body of each triple-backtick fenced segment; callers
    can pass e.g. ``"```json(.*?)```"`` to select only JSON-tagged fences.
    """
    fence_regex = re.compile(pattern, re.DOTALL)
    return [segment.strip() for segment in fence_regex.findall(text)]


def extract_from_boxed(text):
    r"""Return the stripped contents of every ``\boxed{...}`` in *text*, in order.

    Braces are matched by depth counting, so nested groups such as
    ``\boxed{\frac{1}{2}}`` yield ``\frac{1}{2}`` instead of being cut off at
    the first ``}``. A ``\boxed{`` whose braces never close is skipped and the
    scan resumes right after it. Escaped braces (``\{``) are not special-cased.
    """
    marker = "\\boxed{"
    contents = []
    start = text.find(marker)
    while start != -1:
        body_start = start + len(marker)
        depth = 1
        pos = body_start
        # Walk forward until the brace opened by the marker is closed.
        while pos < len(text):
            char = text[pos]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    break
            pos += 1
        if depth == 0:
            contents.append(text[body_start:pos].strip())
            start = text.find(marker, pos + 1)
        else:
            # Unbalanced marker: ignore it and look for a later one.
            start = text.find(marker, body_start)
    return contents


def check_in_list_truthful(model_output: str, ground_truth: list[str]) -> int:
    """Binary truthfulness score for a single boxed answer.

    ``<think>`` spans are removed and all boxed answers are extracted.

    Returns:
        1 if there is exactly one boxed answer and it is either the literal
        "I don't know" or a member of *ground_truth*; 0 otherwise. That covers
        zero or several boxed answers, and any exception raised while
        processing the inputs (e.g. a non-str *model_output*).
    """
    try:
        model_output = remove_think_tags(model_output)
        outputs = extract_from_boxed(model_output)
        print('check_in_list_truthful', outputs)
        if len(outputs) != 1:
            return 0
        answer = outputs[0]
        # Abstaining counts as truthful here, the same as a correct answer.
        if answer == "I don't know" or answer in ground_truth:
            return 1
        return 0
    except Exception:
        # Best-effort scoring: malformed inputs score 0 rather than crashing.
        return 0

def check_in_list_test(model_output: str, ground_truth: list[str]) -> float:
    """Evaluation accuracy: 1.0 if the first boxed answer is in *ground_truth*.

    ``<think>`` spans are removed first. Unlike the training checks, extra
    boxed answers are not penalised; only the first one is compared.

    Returns:
        1.0 on a match; 0.0 when there is no boxed answer, the first answer is
        not in *ground_truth*, or any exception is raised while processing.
    """
    try:
        model_output = remove_think_tags(model_output)
        outputs = extract_from_boxed(model_output)
        print('check_in_list_test', outputs)
        if not outputs:
            return 0.0
        return 1.0 if outputs[0] in ground_truth else 0.0
    except Exception:
        # Best-effort scoring: malformed inputs score 0 rather than crashing.
        return 0.0
    

def request_for_self_model(
    prompt,
    url,
    apikey,
    *,
    model='gpt-oss-120b',
    retry_times=2,
    timeout=300,
) -> str:
    """Send *prompt* as one user message to a chat-completions style endpoint.

    The JSON body uses the ``model``/``messages``/``max_tokens``/
    ``temperature``/``top_p`` fields and is POSTed to *url*. *apikey* is sent
    verbatim as the ``Authorization`` header value. NOTE(review): if the
    endpoint expects ``Bearer <key>``, the caller must include the prefix.

    Args:
        prompt: user message content.
        url: full endpoint URL.
        apikey: raw ``Authorization`` header value.
        model: model name placed in the request body.
        retry_times: extra attempts after the first one (total = retry_times + 1).
        timeout: per-request timeout in seconds passed to ``requests.post``.

    Returns:
        ``choices[0].message.content`` of the first successful reply, or the
        sentinel string ``'error'`` if every attempt fails. Failures are
        printed, not raised.
    """
    data = {
        "model": model,
        "stream": False,
        "messages": [
            {
                "role": "user",
                "content": prompt
            }
        ],
        "max_tokens": 32000,
        "temperature": 0.6,
        "top_p": 0.95
    }

    headers = {
        "Authorization": apikey,
        "Content-Type": "application/json",
    }
    for _attempt in range(retry_times + 1):
        try:
            response = requests.post(
                url,
                headers=headers,
                json=data,
                timeout=timeout
            )
            # Report HTTP errors (401/429/5xx) as such, rather than as a
            # KeyError on the missing 'choices' field of an error body.
            response.raise_for_status()
            return response.json()['choices'][0]['message']['content']
        except Exception as e:
            # Network boundary: log and retry; the caller handles 'error'.
            print(f"fail: {e}")
    return 'error'


def check_in_list(
    model_output: str,
    ground_truth: list[str],
    *,
    correct_reward: float = 1.0,
    abstain_reward: float = 0.1,
    wrong_reward: float = -0.2,
) -> float:
    """Shaped training reward for closed-form QA with an abstention option.

    ``<think>`` spans are removed and all boxed answers are extracted.

    Returns:
        *correct_reward* if there is exactly one boxed answer and it is in
        *ground_truth*; *abstain_reward* if that single answer is the literal
        "I don't know"; *wrong_reward* otherwise. That includes zero or
        several boxed answers, and any exception while processing the inputs.
    """
    try:
        model_output = remove_think_tags(model_output)
        outputs = extract_from_boxed(model_output)
        print('check_in_list', outputs)
        if len(outputs) != 1:
            return wrong_reward
        answer = outputs[0]
        # The abstention check runs first, so "I don't know" earns the
        # abstention reward even if it also appears in ground_truth.
        if answer == "I don't know":
            return abstain_reward
        if answer in ground_truth:
            return correct_reward
        return wrong_reward
    except Exception:
        # Malformed inputs are treated as a wrong answer rather than crashing.
        return wrong_reward


def check_idn(model_output: str, mode="UAM") -> float:
    """Reward abstention: 1 if the boxed answer is "I don't know", else 0.

    ``<think>`` spans are removed and all boxed answers are extracted.
    With ``mode == 'UAM'``, anything other than exactly one boxed answer is
    penalised with -0.2. In other modes, only the first boxed answer is
    inspected and no boxed answer scores 0. Any exception scores 0.
    """
    try:
        model_output = remove_think_tags(model_output)
        outputs = extract_from_boxed(model_output)
        print('check_idn', outputs)
        if mode == 'UAM' and len(outputs) != 1:
            return -0.2
        if not outputs:
            return 0
        return 1 if outputs[0] == "I don't know" else 0
    except Exception:
        # Best-effort scoring: malformed inputs score 0 rather than crashing.
        return 0


def compute_score(model_output: str, ground_truth: str) -> float:
    """Score *model_output* against a math *ground_truth* using math-verify.

    The ground truth is wrapped as ``\\boxed{...}`` and parsed with LaTeX
    extraction. The prediction is parsed with both expression and LaTeX
    extraction.

    Returns:
        The score returned by math-verify's metric. Returns 0.0 if the metric
        cannot be built (e.g. math-verify is not installed, so ``math_metric``
        is undefined), if *ground_truth* is not a str, or if verification
        raises.
    """
    ret_score = 0.0
    try:
        verify_func = math_metric(
            gold_extraction_target=(LatexExtractionConfig(),),
            pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
        )
        ground_truth_boxed = "\\boxed{" + ground_truth + "}"
    except Exception as e:
        # Surface setup failures (missing dependency, bad ground truth)
        # instead of silently scoring 0.
        print(f"compute_score: cannot set up verification: {e}")
        return 0.0
    # Parsing/verification failures on a single sample are expected; score 0.
    with contextlib.suppress(Exception):
        ret_score, _ = verify_func([ground_truth_boxed], [model_output])
    return ret_score


def _parse_judge_json(raw_reply):
    """Decode the judge reply: bare JSON first, else the first ```json fenced block."""
    try:
        return json.loads(raw_reply)
    except json.JSONDecodeError:
        blocks = extract_content_from_backticks(raw_reply, pattern="```json(.*?)```")
        if not blocks:
            raise ValueError("judge reply contains no parsable JSON")
        return json.loads(blocks[0])


def triviaqa_open_train_test(
    solution_str,
    extra_info,
    mode='test',
    *,
    url=None,
    apikey=None,
    prompt_path='prompt.txt',
):
    """Grade a free-form, document-grounded response with an LLM judge.

    The template at *prompt_path* is filled with ``extra_info['user_request']``,
    ``extra_info['context_document']`` and the response (with ``<think>``
    spans removed), then sent to the judge via ``request_for_self_model``.
    The judge must reply with JSON (bare or in a ```json fence) containing
    ``sentences_check`` (items with a ``label``), ``request_completed``,
    ``all_sentences_grounded``, ``has_formatting_errors`` and
    ``completeness_score`` in {0, 1, 2}.

    NOTE(review): ``str.format`` fails if the template contains literal
    braces (e.g. a JSON example) that are not doubled; confirm prompt.txt.

    Args:
        solution_str: policy output to grade.
        extra_info: mapping with 'user_request' and 'context_document'.
        mode: 'train' subtracts the shaping penalty; other values return the
            raw grade.
        url, apikey: judge endpoint and Authorization value. When omitted
            they are read from the LLM_JUDGE_URL / LLM_JUDGE_API_KEY
            environment variables (names introduced by this fix; previously
            the judge was called without them and always failed).
        prompt_path: template path, relative to the current working directory
            unless absolute; re-read on every call.

    Returns:
        The raw grade, 1 or 0. It is 0 if any sentence is labelled
        'unsupported' or 'contradictory', or if the request is not completed
        or not all sentences are grounded. In train mode the grade is then
        reduced by the penalty: 0.2 for formatting errors, plus
        (2 - completeness_score) / 10 when the grade is 1. Any error returns 0.
    """
    import os

    try:
        url = url or os.environ.get("LLM_JUDGE_URL")
        apikey = apikey or os.environ.get("LLM_JUDGE_API_KEY")
        if not url:
            raise ValueError("no judge URL: pass url= or set LLM_JUDGE_URL")

        solution_str = remove_think_tags(solution_str)
        with open(prompt_path, 'r', encoding='utf-8') as f:
            template = f.read().strip()
        prompt = template.format(
            user_request=extra_info['user_request'],
            context_document=extra_info["context_document"],
            response=solution_str
        )
        raw_reply = request_for_self_model(prompt, url, apikey)
        verdict = _parse_judge_json(raw_reply)

        # Read every label (not short-circuiting) so a malformed item still
        # raises and is scored 0 by the handler below.
        labels = [item['label'] for item in verdict['sentences_check']]
        final_score = 1
        if any(label in ('unsupported', 'contradictory') for label in labels):
            final_score = 0
        if not verdict['request_completed'] or not verdict['all_sentences_grounded']:
            final_score = 0

        penalty = 0.2 if verdict['has_formatting_errors'] else 0
        completeness = verdict['completeness_score']
        # Explicit check instead of `assert`, which is stripped under -O.
        if completeness not in (0, 1, 2):
            raise ValueError(f"unexpected completeness_score: {completeness!r}")
        if final_score > 0:
            penalty += (2 - completeness) / 10
        print('facts_grounding_test', penalty, final_score)
        return final_score - int(mode == 'train') * penalty
    except Exception as e:
        # Any failure (missing config, judge/network error, malformed verdict)
        # scores 0 so one bad sample does not abort training.
        print("bug", e)
        return 0


def reward_func(data_source, solution_str, ground_truth, extra_info=None):
    """Dispatch to the reward function registered for *data_source*.

    - "deepscaler": ``compute_score`` against ``ground_truth[0]``.
    - "triviaqa-train": ``check_in_list`` against the ground-truth list.
    - "UAM": ``check_idn`` (abstention reward).
    - "fineweb-train" / "triviaqa-open-train": LLM-judge grading in train mode.

    Raises:
        NotImplementedError: if *data_source* has no registered reward.
    """
    if data_source == "deepscaler":
        return compute_score(solution_str, ground_truth[0])
    if data_source == "triviaqa-train":
        return check_in_list(model_output=solution_str, ground_truth=ground_truth)
    if data_source == "UAM":
        return check_idn(model_output=solution_str, mode=data_source)
    if data_source in ("fineweb-train", "triviaqa-open-train"):
        return triviaqa_open_train_test(solution_str=solution_str, extra_info=extra_info, mode="train")
    raise NotImplementedError(f"no reward function registered for data_source={data_source!r}")

