import ast
import json
import re
import signal
import sys
import traceback
from fractions import Fraction
from types import ModuleType
from typing import Union

import requests
import torch
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

from agents.rody_calling_mock import *
from agents.expert_calling_mock import *
from agents.service_utils import *

# System prompt for the QA judge, used when grading 'expert' and 'toolbench'
# answers. It is sent verbatim as the system message and is never passed
# through .format(). The judge therefore sees the doubled braces and the
# {placeholders} below literally. The actual question/reference/answer arrive
# separately in the user turn.
system_qa_check = """You need to act as a judge to score others. Please carefully and thoroughly understand the user’s conversation content and the response situation, and in combination with the content of the reference answer, act as an impartial judge to score the AI assistant’s reply to the user’s current question, as well as the accuracy of conveying the API information present in the question(if exists).  
### Input Format  
User's question:{question}  
[Reference answer start]  
{reference}  
[Reference answer end]  
******
[Assistant's answer start]  
{answer}  
[Assistant's answer end]  
******

### Output Format  
{{'Score': ..}}

### Scoring Process  
Follow these steps to evaluate the assistant's answer:  
1. **Relevance and Safety Check**: Determine if the assistant's reply contains harmful content, is irrelevant to the user's question, exhibits off-topic responses, or shows other clear signs of being entirely ineffective. If any of these issues exist, directly return a score of -1.  
Output format:  {{'Score': -1}}
2. **Accuracy and Similarity to Reference Answer**: When the assistant’s response is deemed valid and harmless, its accuracy is evaluated based on both its correctness and similarity to the reference answer, with a score returned in the range of [0, 1].
- When the question and the reference answer contain API invocation results such as “Pass calling… the API call was… Results are as follows:…”, although these descriptions are not part of the question itself, it is required that the answer fully conveys this API invocation information. Therefore, it is necessary to evaluate the similarity of whether the answer accurately conveys the API invocation information, mainly focusing on the name of the invoked API and its specific parameters.
- For non-API-related responses, the evaluation mainly considers the similarity to the reference answer and consistency with the question.
- Pay Attention!!!!! When the question and reference answer contain API invocation results, the similarity of this part accounts for 0.5 points, while the Q&A similarity constitutes the remaining 0.5 points. That means, if the answer misses this part, the maximum score it can get is 0.5.
- If no API invocation is involved, the Q&A similarity is scored out of 1 point.
Output format:  {{'Score': ..}}
Note: The score should be a 2-decimal value within [0, 1].eg.{{'Score': 0.75}}

### Instructions  
Analyze the user's question and the assistant's response based on the above scoring criteria. Before assigning a score, provide an evaluation and explanation, then output the final score in the specified format."""

# User-turn template for the QA judge. It is filled via
# .format(question=..., reference=..., answer=...) before being sent.
user_init_qa = """User's question:{question}  
[Reference answer start]  
{reference}  
[Reference answer end]  
[Assistant's answer start]  
{answer}  
[Assistant's answer end]"""

# System prompt for the math judge, which is asked to reply TRUE or FALSE.
# It is sent verbatim as the system message (never .format()-ed), so the
# {placeholders} in the "Input Format" section are only illustrative.
system_math_check = """You are an expert proficient in mathematical problems and now need to act as a judge to score others’ answers. Please carefully understand the user’s conversation content and the standard answer, and evaluate whether the current assistant’s response is correct.
### Input Format  
User’s question: {question}
[Standard Answer Start] {reference} [Standard Answer End]
[Assistant’s Answer Start]
{answer}
[Assistant’s Answer End]

### Output Format  
If the assistant’s answer is correct, output TRUE; otherwise, output FALSE.

### Scoring Process  
You need to evaluate the assistant’s answer according to the following process:
1.Carefully read the question and the assistant’s final answer, extracting the assistant’s response result. Note that the final answer is usually at the end of the response.
2.Analyze whether the assistant’s answer matches the standard answer. If they are consistent, return TRUE; otherwise, output FALSE.

### Instructions  
Please analyze the user’s question and the answer situation, and directly provide the judgment in the specified output format based on the above scoring criteria."""

# User-turn template for the math judge. It is filled via
# .format(question=..., reference=..., answer=...) before being sent.
# The layout mirrors the "Input Format" section of the math system prompt.
# The question appears exactly once; a duplicated question line would make the
# judge read it twice.
user_init_math = """User’s question: {question}
[Standard Answer Start] {reference} [Standard Answer End]
[Assistant’s Answer Start]
{answer}
[Assistant’s Answer End]"""


class RewardModel:
    """Score predicted answers against references for several task types.

    Rewards are plain numbers:
    - ``FULL_REWARDS`` for a correct answer;
    - ``ERROR_REWARDS`` for a wrong one;
    - ``ABNORMAL_REWARDS`` for malformed output.

    QA-style tasks ('expert', 'toolbench') may instead receive the fractional
    score parsed from an LLM judge's reply.
    """

    # ``'Score': <number>`` or ``"Score": <number>`` as written by the QA judge.
    _SCORE_PATTERN = re.compile(r"""['"]Score['"]\s*:\s*(-?(?:\d+(?:\.\d*)?|\.\d+))""")

    # Standalone verdict tokens from the math judge. The prompt asks for
    # upper-case TRUE/FALSE, so those are preferred over incidental prose.
    _VERDICT_UPPER = re.compile(r"\b(TRUE|FALSE)\b")
    _VERDICT_ANY = re.compile(r"\b(true|false)\b", re.IGNORECASE)

    # The two quoting styles in which a function-call answer names its API.
    _API_NAME_PATTERNS = (
        re.compile(r"'api_name': '([^']+)'"),
        re.compile(r'"api_name": "([^"]+)"'),
    )

    # Exceptions ast.literal_eval raises for text that is not a plain literal.
    _LITERAL_EVAL_ERRORS = (ValueError, TypeError, SyntaxError, MemoryError, RecursionError)

    def __init__(self):

        self.FULL_REWARDS = 1  # A perfect score
        self.ZERO_REWARDS = 0  # No reward, no punishment
        self.FORMAT_REWARDS = 0.1  # Only format score
        self.ERROR_REWARDS = -1  # Define the score for the error (cannot be 0, which is equivalent to no penalty)
        self.ABNORMAL_REWARDS = -2  # Penalty for abnormal return

    def open_ai_llm_score(self, question, reference, answer, math_task=False, model='gpt-4o'):
        """Ask the LLM judge to grade ``answer`` against ``reference``.

        Args:
            question: The user's question.
            reference: The reference / standard answer.
            answer: The answer being graded.
            math_task: Use the TRUE/FALSE math prompts instead of the scored QA prompts.
            model: Judge model name forwarded to ``call_llm_messages``.

        Returns:
            Whatever ``call_llm_messages`` returns. Callers treat it as the
            judge's reply text.
        """
        if math_task:
            system, user_init = system_math_check, user_init_math
        else:
            system, user_init = system_qa_check, user_init_qa
        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": user_init.format(question=question, reference=reference, answer=answer)},
        ]
        return call_llm_messages(model, messages)

    def score_extract(self, score_str):
        """Pull the numeric ``'Score'`` out of a QA judge reply.

        The judge is told to explain before scoring and may quote the
        ``{'Score': ...}`` template in that explanation. The *last* score in the
        reply is therefore taken as the verdict. Single- and double-brace
        wrappers are both accepted. The reply is never evaluated as code.

        Returns:
            The score as an int or float, or ``None`` if no score is found.
        """
        if isinstance(score_str, str):
            matches = self._SCORE_PATTERN.findall(score_str)
            if matches:
                # literal_eval keeps "-1" an int and "0.75" a float.
                return ast.literal_eval(matches[-1])
        print('error', score_str)
        return None

    def compute_qa_answer_reward_score(self, query, real_answer, pred_answer):
        """
        Evaluate the effectiveness of expert and toolbench task model answers.

        Returns the judge's score, or ``ZERO_REWARDS`` when no score can be parsed.
        """
        llm_score = self.open_ai_llm_score(query, real_answer, pred_answer)
        score = self.score_extract(llm_score)
        return self.ZERO_REWARDS if score is None else score

    def _parse_math_verdict(self, llm_result):
        """Map the math judge's reply to ``FULL_REWARDS`` (TRUE) or ``ERROR_REWARDS``.

        Only whole-word TRUE/FALSE tokens count, so "untrue" is not a match.
        The last verdict wins, which ignores earlier mentions inside reasoning.
        A non-string or verdict-less reply is treated as FALSE.
        """
        if not isinstance(llm_result, str):
            return self.ERROR_REWARDS
        verdicts = self._VERDICT_UPPER.findall(llm_result) or self._VERDICT_ANY.findall(llm_result)
        if verdicts and verdicts[-1].lower() == 'true':
            return self.FULL_REWARDS
        return self.ERROR_REWARDS

    def compute_math_answer_reward_score(self, query, real_answer, pred_answer):
        """
        Judging the effectiveness of the model's answers to math tasks via the LLM judge.
        """
        llm_result = self.open_ai_llm_score(query, real_answer, pred_answer, math_task=True)
        return self._parse_math_verdict(llm_result)

    def compute_numeric_consistency_score(self, query, num_real: str, num_pred: str) -> int:
        """Score a math answer by numeric equality, deferring to the LLM judge.

        Args:
            query: The original question (forwarded to the judge if needed).
            num_real: Reference answer string.
            num_pred: Predicted answer string.

        Both strings have thousands separators removed. Any text up to and
        including the last ``</think>`` is also dropped.

        Returns:
            - ``FULL_REWARDS`` on a textual or numeric match;
            - ``ABNORMAL_REWARDS`` for empty or non-string inputs;
            - otherwise the LLM judge's verdict.
        """

        def math_trans(num):
            """Convert an answer string to a number.

            Handles ``$`` amounts, decimals, ``a/b`` fractions and integers.
            Raises ``ValueError`` / ``ArithmeticError`` for anything else.
            """
            num = num.strip()
            if '$' in num:
                return float(num.split('$')[1])
            if '.' in num:
                return float(num)
            if '/' in num:
                # Parse "a/b" with Fraction: model output must never reach
                # eval(). float(Fraction) performs the same int true division.
                return float(Fraction(num))
            return int(num)

        if not num_pred or not num_real:
            print(f"Error! Missing matching variable：num_pred is {num_pred}, num_real is {num_real}.\n")
            return self.ABNORMAL_REWARDS

        if not (isinstance(num_pred, str) and isinstance(num_real, str)):
            print(
                f"Error! num_pred and num_real are both required to be string variables. The current character types are as follows：num_pred is {num_pred}, the type is {type(num_pred)}; num_real is{num_real}, the type is{type(num_real)}.\n")
            return self.ABNORMAL_REWARDS

        # The final answer follows the last </think>, not the first.
        num_pred = num_pred.replace(',', '').rsplit('</think>', 1)[-1].strip()
        num_real = num_real.replace(',', '').rsplit('</think>', 1)[-1].strip()

        if num_pred == num_real:  # Character-level direct matching
            return self.FULL_REWARDS

        try:
            if math_trans(num_pred) == math_trans(num_real):
                return self.FULL_REWARDS
        except (ValueError, ArithmeticError):
            pass  # Not a plain number; let the LLM judge decide below.
        # The judge call sits outside the try, so a failure there is
        # surfaced instead of being retried or swallowed.
        return self.compute_math_answer_reward_score(query, num_real, num_pred)

    def _extract_api_name(self, text):
        """Return the first ``api_name`` value quoted in ``text``, or ''.

        Single-quoted form is checked first, then double-quoted.
        """
        for pattern in self._API_NAME_PATTERNS:
            match = pattern.search(text)
            if match:
                return match.group(1)
        return ''

    def compute_function_call_answer_score(self, pred_answer, real_answer):
        """Score a function-call ('rody') answer.

        Returns:
            - ``FULL_REWARDS`` for an exact textual match;
            - half of ``FULL_REWARDS`` when only the (non-empty) API name agrees;
            - ``ERROR_REWARDS`` otherwise.
        """
        if pred_answer == real_answer:
            return self.FULL_REWARDS
        real_api = self._extract_api_name(real_answer)
        pred_api = self._extract_api_name(pred_answer)
        if real_api and real_api == pred_api:
            return 0.5 * self.FULL_REWARDS
        return self.ERROR_REWARDS

    def valid_pred_answer_check(self, ans):
        """Reject predicted answers that are empty or structurally broken.

        Returns:
            ``(True, None)`` if ``ans`` may be scored. Otherwise
            ``(False, penalty)``, where penalty is ``ABNORMAL_REWARDS`` or
            ``ERROR_REWARDS``.
        """
        if not isinstance(ans, str) or ans.strip() == '':
            return False, self.ABNORMAL_REWARDS

        # Reject an opening special token without its closing partner, or vice versa.
        special_tokens = [('<tool_call>', '</tool_call>'), ('<think>', '</think>')]
        for start_token, end_token in special_tokens:
            if (start_token in ans) != (end_token in ans):
                return False, self.ABNORMAL_REWARDS

        fail_types_str = ['timeout planning failed', 'output parsing failed']
        if any(fail_type in ans for fail_type in fail_types_str):
            return False, self.ABNORMAL_REWARDS

        # An answer that parses in full as a Python literal (e.g. a raw
        # dict/list) is abnormal. literal_eval only parses literals; eval()
        # here would execute arbitrary model output.
        if any(ch in ans for ch in '{}[]'):
            try:
                ast.literal_eval(ans)
            except self._LITERAL_EVAL_ERRORS:
                pass
            else:
                return False, self.ABNORMAL_REWARDS

        if ans in ['The retrievals format is invalid.', 'answer']:
            return False, self.ERROR_REWARDS

        return True, None

    def get_reward(self, query_type, query, pred_answer, real_answer):
        """Return the reward for ``pred_answer`` on a task of ``query_type``.

        Args:
            query_type: One of 'expert', 'rody', 'toolbench' or 'math'.
            query: The question text.
            pred_answer: The model's answer.
            real_answer: The reference answer.

        Raises:
            ValueError: If ``query_type`` is unknown. This is only raised for
                answers that pass the validity check.
        """
        valid_flag, score = self.valid_pred_answer_check(pred_answer)
        if not valid_flag:
            return score

        if query_type in ('expert', 'toolbench'):
            # Toolbench is the result returned by the comprehensive API calls,
            # so it is graded with the same QA judge as expert answers.
            return self.compute_qa_answer_reward_score(query, real_answer, pred_answer)
        if query_type == 'rody':
            return self.compute_function_call_answer_score(pred_answer, real_answer)
        if query_type == 'math':
            return self.compute_numeric_consistency_score(query, real_answer, pred_answer)
        raise ValueError('no such query type! please check!')
