# System prompt for the LLM judge used by RewardActorAgentStage1.
# It asks the judge to grade an agent's interaction on a 0/1/2 scale, against
# the task instruction and a reference interaction. The judge must reply with
# a JSON object holding "reason" and "score".
# NOTE(review): this string is passed as-is and is never .format()-ed in this
# file. The doubled braces "{{" / "}}" are presumably escapes for template
# formatting done inside oneapi_post_by_langchain — TODO confirm; otherwise
# the judge sees literal "{{".
# NOTE(review): the illustrative "## Input" JSON has a trailing comma after
# "reference_interaction", so it is not valid JSON. Consider fixing it the
# next time the prompt is revised (changing it alters the reward signal).
VERIFY_SYSTEM_PROMPT_EN = """## Role
Please evaluate the agent based on each step of its actions and the final resolution of the instruction.

## Task
Based on (1) the task instruction (2) the agent's actions according to the environment's feedback (3) the agent's final resolution (4) the reference interaction of the task
Judge the agent's actions as well as final resolution, and give a score of 0/1/2
2-point answer criteria:
The final resolution is completely consistent with the goal of the task, and the actions are reasonable and effective compared to the reference interaction.
1-point answer criteria:
The final resolution is closely consistent with the goal of the task, and the actions are relatively reasonable and effective compared to the reference interaction.
0-point answer criteria:
The final resolution is inconsistent with the goal of the task, and the actions are not reasonable and effective enough compared to the reference interaction.
Please output the specific score after giving the scoring reason

## Input
{{
    "instruction": "instruction",
    "interaction": "agent's actions and environment's feedback, as well as the final resolution",
    "reference_interaction": "reference interaction",
}}

## Output Format
Please return in the json format below
```json
{{
    "reason": "scoring reason",
    "score": 0/1/2
}}
```
""".strip()

# User-turn template for the LLM judge. It is filled via str.format with
# `instruction` (the task), `observation` (the agent's answer text) and
# `conversations` (the reference interaction). There is no trailing newline.
VERIFY_USER_PROMPT_EN = "\n".join(
    [
        "## Instruction:",
        "{instruction}",
        "## Agent's Interaction:",
        "{observation}",
        "## Reference Interaction:",
        "{conversations}",
    ]
)

import re
import time
from my_reward.utils.time_utils import timeprint
from my_reward.api import oneapi_post_by_langchain, read_json
from my_reward.auxiliary.format_reward import (
    get_think_and_answer
)
from my_reward.contrib.base import RewardActorBase
from pydantic import BaseModel, Field

# Pydantic schema mirroring the JSON reply the judge is asked to produce
# ("reason" + "score"). In this file it is not passed to
# oneapi_post_by_langchain; the base_model=Score argument is commented out.
# Documented with comments rather than a docstring on purpose: pydantic copies
# a class docstring into the model's JSON-schema "description".
class Score(BaseModel):
    # Free-text justification written by the judge.
    reason: str = Field(..., title="Scoring Reason", description="Scoring Reason")
    # Judge score; the system prompt requests 0, 1 or 2.
    score: float = Field(..., title="Score", description="Score")

class RewardActorAgentStage1(RewardActorBase):
    """LLM-judge reward for agent trajectories (stage 1).

    Each response is first checked with ``compute_format_score``. Badly
    formatted responses get ``cls.default + format_score / 10`` and are not
    judged. For the rest, the answer part of the response is sent to an LLM
    judge together with the task instruction and the reference interaction
    from ``extra_info``. The judge's 0/1/2 score is then mapped to a reward.
    """

    # Number of judge prompts sent per oneapi_post_by_langchain call.
    # Override in a subclass to change the request size.
    verify_batch_size = 256

    @classmethod
    def batch_compute_score(
        cls,
        params,
        data_source_list,
        prompt_str_list,
        response_str_list,
        ground_truth_list,
        extra_info_list,
        finish_reason_list=None
    ):
        """Score a batch of agent responses with the LLM judge.

        Args:
            params: Forwarded as keyword arguments to
                ``oneapi_post_by_langchain`` on every batch call.
            data_source_list: Per-sample data-source tags (not used here).
            prompt_str_list: Per-sample policy prompts, passed to
                ``compute_format_score``.
            response_str_list: Per-sample policy responses to score.
            ground_truth_list: Per-sample ground truth (not used here; the
                reference comes from ``extra_info["conversations"]``).
            extra_info_list: Per-sample dicts; formatted samples must provide
                ``"instruction"`` and ``"conversations"`` (KeyError otherwise).
            finish_reason_list: Optional per-sample finish reasons for
                ``compute_format_score``; defaults to all ``None``.

        Returns:
            The value of ``cls.add_penalty`` applied to a list with one dict
            per response. Each dict holds ``"reason"`` and ``"reward"``, plus
            ``"exception"`` when the judge reply could not be parsed.
        """
        if finish_reason_list is None:
            finish_reason_list = [None] * len(prompt_str_list)

        # Every sample starts at the default reward. Entries are overwritten
        # below by the format check or by the judge verdict.
        result = [
            {"reason": "DEFAULT", "reward": cls.default}
            for _ in response_str_list
        ]

        # Positions in `result` still awaiting a verdict, paired by position
        # with the judge prompts built for them.
        index_list = []
        prompt_list = []
        samples = zip(data_source_list, prompt_str_list, response_str_list,
                      ground_truth_list, extra_info_list, finish_reason_list)
        for i, (_, prompt_str, response_str, _, extra_info, finish_reason) in enumerate(samples):
            format_score = cls.compute_format_score(prompt_str, response_str, finish_reason)
            if format_score != 1.0:
                # Not sent to the judge: small reward shaped by the format score.
                result[i] = {
                    "reason": "FORMAT_WRONG",
                    "reward": cls.default + format_score / 10.0
                }
                continue

            _, answer_str = get_think_and_answer(response_str)
            conversations = extra_info["conversations"]
            instruction = extra_info["instruction"]
            prompt_list.append(VERIFY_USER_PROMPT_EN.format(
                instruction=instruction,
                observation=answer_str,
                conversations=conversations
            ))
            index_list.append(i)

        response_list = cls._query_verifier(prompt_list, params)
        if len(response_list) != len(prompt_list):
            # Replies are matched to samples by position. On a count mismatch,
            # zip() drops the surplus, so some samples silently keep the
            # DEFAULT result (or replies may be misaligned). Make it visible.
            timeprint(
                f"-------- agent stage 1 WARNING: got {len(response_list)} verifier "
                f"responses for {len(prompt_list)} prompts"
            )

        for index, res in zip(index_list, response_list):
            result[index] = cls._verdict_to_result(res)

        return cls.add_penalty(result, prompt_str_list, response_str_list, extra_info_list)

    @classmethod
    def _query_verifier(cls, prompt_list, params):
        """Send judge prompts in chunks of ``cls.verify_batch_size``; return raw replies in order."""
        response_list = []
        batch_size = cls.verify_batch_size
        for start in range(0, len(prompt_list), batch_size):
            chunk = prompt_list[start:start + batch_size]
            stt = time.time()
            response_list += oneapi_post_by_langchain(
                prompt=chunk,
                system_prompt=VERIFY_SYSTEM_PROMPT_EN,
                # base_model=Score,  # structured output disabled; replies go through read_json
                **params
            )
            edt = time.time()
            timeprint(f"-------- agent stage 1 compute score batch size: {len(chunk)}, oneapi time: {edt - stt} s")
        return response_list

    @classmethod
    def _verdict_to_result(cls, res):
        """Turn one raw judge reply into a ``{"reason", "reward"}`` result dict."""
        try:
            res_json = read_json(res)
            reason = res_json["reason"]
            score = float(res_json["score"])
        except Exception as e:
            # Best effort: any unparseable or incomplete reply (including
            # whatever read_json raises) becomes an error entry instead of
            # failing the batch.
            return {
                "reason": f"ERROR IN VERIFY: {res}",
                "reward": cls.default_format_score,
                "exception": str(e)
            }

        # 0/1/2 judge scale -> reward; scores below 1 (including NaN) get the
        # base-class default_format_score.
        if score >= 2.0:
            reward = 1.0
        elif score >= 1.0:
            reward = 0.4
        else:
            reward = cls.default_format_score
        return {"reason": reason, "reward": reward}