# System prompt for the LLM judge: describes the judging task, the 0/1/2
# scoring rubric, the shape of the input, and the JSON reply format
# ("reason" followed by "score").
# NOTE(review): the doubled braces below look like str.format escapes. If this
# constant is sent to the model without going through .format(), the judge
# receives literal "{{" / "}}" -- confirm which form is intended.
VERIFY_SYSTEM_PROMPT_EN = """## Role
Please evaluate the agent based on each step of its actions and the final resolution of the instruction.

## Task
Based on (1) the task instruction (2) the agent's actions according to the environment's feedback (3) the agent's final resolution (4) the reference interaction of the task
Judge the agent's actions as well as final resolution, and give a score of 0/1/2
2-point answer criteria:
The final resolution is completely consistent with the goal of the task, and the actions are reasonable and effective compared to the reference interaction.
1-point answer criteria:
The final resolution is closely consistent with the goal of the task, and the actions are relatively reasonable and effective compared to the reference interaction.
0-point answer criteria:
The final resolution is inconsistent with the goal of the task, and the actions are not reasonable and effective enough compared to the reference interaction.
Please output the specific score after giving the scoring reason

## Input
{{
    "instruction": "instruction",
    "interaction": "agent's actions and environment's feedback, as well as the final resolution",
    "reference_interaction": "reference interaction",
}}

## Output Format
Please return in the json format below
```json
{{
    "reason": "scoring reason",
    "score": 0/1/2
}}
```
""".strip()

# User-message template for the LLM judge, filled with str.format using:
#   {instruction}   - the task instruction
#   {observation}   - the agent's interaction (the response text being scored)
#   {conversations} - the reference interaction for the task
VERIFY_USER_PROMPT_EN = """## Instruction:
{instruction}
## Agent's Interaction:
{observation}
## Reference Interaction:
{conversations}
""".strip()

import copy
import json
import math
import re
import time
import traceback
from typing import Dict, List

import numpy as np
import requests
from tqdm import tqdm

from my_reward.api import (
    oneapi_post,
    oneapi_post_by_langchain,
    read_json
)
from my_reward.auxiliary.format_reward import (
    get_think_and_answer
)
from my_reward.contrib.base import RewardActorBase, Reason
# from pydantic import BaseModel, Field

# class Score(BaseModel):
#     reason: str = Field(..., title="Scoring Reason", description="Scoring Reason")
#     score: float = Field(..., title="Score", description="Score")

class RewardActorAgentStage2(RewardActorBase):
    """Reward actor that scores an agent rollout with an LLM judge.

    A response that passes the format check is sent, together with the task
    instruction and the reference interaction, to the judge model configured
    in ``params``. The judge's reward is then combined with a weighted
    global-plan score.
    """

    @classmethod
    def compute_score(
        cls,
        params,
        data_source,
        prompt_str,
        response_str,
        ground_truth,
        extra_info,
        global_plan_score=0.0,
    ):
        """Compute the reward for a single response.

        Args:
            params: Judge API settings. ``url`` and ``model`` are required;
                ``key``, ``max_tokens``, ``temperature`` and ``top_p`` are
                optional and fall back to the defaults used below.
            data_source: Unused here; kept for interface compatibility.
            prompt_str: Prompt text, forwarded to the format check and to
                ``add_penalty``.
            response_str: Agent response; format-checked and shown to the
                judge as the agent's interaction.
            ground_truth: Unused here; kept for interface compatibility.
            extra_info: Mapping that must contain ``"instruction"`` and
                ``"conversations"`` (the reference interaction) whenever the
                response passes the format check.
            global_plan_score: Optional numeric (or numeric-string) score that
                is added to the judge reward with weight 0.8. Non-numeric and
                non-finite values are ignored.

        Returns:
            A dict with at least a ``"reward"`` entry. On a format failure or
            after all judge attempts fail it also carries a ``"reason"``;
            otherwise it is the dict produced by ``cls.get_final_reward`` with
            its reward adjusted.
        """
        # Hyperparameter: weight of the global-plan score in the final reward.
        alpha = 0.8

        format_score = cls.compute_format_score(prompt_str, response_str)
        if format_score == 0.0:
            result = {
                "reason": Reason.FORMAT_WRONG.value,
                "reward": cls.default,
            }
        else:
            conversations = extra_info["conversations"]
            instruction = extra_info["instruction"]
            user_prompt = VERIFY_USER_PROMPT_EN.format(
                instruction=instruction,
                observation=response_str,
                conversations=conversations,
            )
            # The system and user parts are sent as one combined prompt.
            judge_prompt = VERIFY_SYSTEM_PROMPT_EN + "\n" + user_prompt

            # Ask the judge, retrying while the parsed result reports an
            # exception.
            stt = time.time()
            result = {}
            succeeded = False
            attempts_left = cls.api_retries
            while attempts_left > 0:
                answer_response = oneapi_post(
                    prompt=judge_prompt,
                    url=params["url"],
                    model=params["model"],
                    key=params.get("key", "EMPTY"),
                    max_tokens=params.get("max_tokens", 4096),
                    temperature=params.get("temperature", 0.9),
                    top_p=params.get("top_p", 0.6),
                )
                result = cls.get_final_reward(answer_response)
                if not result.get("exception"):
                    succeeded = True
                    break
                attempts_left -= 1
            print(f"########### Time for verify answer response: {time.time() - stt}")

            if not succeeded:
                # Covers both exhausted retries and a non-positive retry
                # budget, so ``result`` always carries a "reward" entry below.
                result = {
                    "reward": cls.default,
                    "reason": Reason.API_ERROR.value,
                    "exception": "Failed to get verify score after retries",
                }
            elif result.get("exception") is None:
                bonus = 0.0
                if is_number(global_plan_score):
                    plan_score = float(global_plan_score)
                    # Skip NaN/inf: max(nan, 0.0) would return NaN and poison
                    # the reward instead of clamping it.
                    if math.isfinite(plan_score):
                        bonus = alpha * plan_score
                result["reward"] = max(result["reward"] + bonus, 0.0)

        result["reward"] = cls.add_penalty(result["reward"], prompt_str, response_str, extra_info)
        return result
    

# # handle the global plan score and return the average score
# def handle_global_plan_score(global_plan_score_list):
#     """
#     1. default: [[], [], ...]
#     2. [[{}, {}, ...], [{}, {}, ...], ...]
#     """
#     if not global_plan_score_list or len(global_plan_score_list) == 0 or all(isinstance(item, list) and len(item) == 0 for item in global_plan_score_list):
#         return 0.0
#     else:
#         avg_score_list = []
#         for item in global_plan_score_list:
#             correctness = float(item["correctness_score"]) if is_number(item["correctness_score"]) else 0.0
#             followability = float(item["followability_score"]) if is_number(item["followability_score"]) else 0.0
#             standardization = float(item["standardization_score"]) if is_number(item["standardization_score"]) else 0.0
#             avg_score = (correctness + followability + standardization) / 3.0
#             avg_score_list.append(avg_score)
#         total_avg_score = sum(avg_score_list) / len(avg_score_list) if avg_score_list and len(avg_score_list) > 0 else 0.0
#         return total_avg_score


def is_number(element):
    """Return True if *element* is an int/float or a string ``float()`` can parse.

    Booleans count as numbers because ``bool`` is a subclass of ``int``.
    Any other type (``None``, lists, dicts, ...) yields False.
    """
    if isinstance(element, str):
        # A string qualifies only if float() accepts it.
        try:
            float(element)
        except ValueError:
            return False
        return True
    return isinstance(element, (int, float))