import logging
import os
from typing import Any, Optional
from uuid import uuid4

from verl.utils.rollout_trace import rollout_trace_op
from verl.tools.base_tool import BaseTool
from verl.tools.schemas import OpenAIFunctionToolSchema, ToolResponse
from recipe.fileagent.utils.metric_utils import build_tool_metric
from recipe.fileagent.reward_score import compute_score

# Module-level logger; its threshold is read from the VERL_LOGGING_LEVEL
# environment variable and falls back to "WARN" when the variable is unset.
logger = logging.getLogger(__name__)
logger.setLevel(os.environ.get("VERL_LOGGING_LEVEL", "WARN"))


class RewardTool(BaseTool):
    """A tool for calculating the reward defined by FileAgent.

    Per-trajectory state (latest response, ground truth, last reward, data
    source and extra info) is stored in ``self._instance_dict`` keyed by
    instance id.

    - `get_openai_tool_schema`: return the tool schema in OpenAI format.
    - `create`: create a tool instance for a trajectory.
    - `execute`: execute the tool.
    - `calc_reward`: calculate the reward with respect to the tool state.
    - `release`: release the tool instance.
    """

    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):
        """Initialize the tool and its (initially empty) instance table.

        Args:
            config: Tool configuration, forwarded unchanged to ``BaseTool``.
            tool_schema: Tool schema, forwarded unchanged to ``BaseTool``.

        Example tool schema:

        _tool_schema = OpenAIFunctionToolSchema.model_validate({
            "type": "function",
            "function": {
                "name": "calc_reward",
                "description": "Evaluate the model predicted answer using the sample ground truth and return a reward.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "answer": {
                            "type": "string",
                            "description": "The answer to the question",
                        },
                    },
                    "required": ["answer"],
                },
            }
        })
        """
        super().__init__(config, tool_schema)
        # instance_id -> state dict written by `create` and updated by `execute`.
        self._instance_dict: dict[str, dict[str, Any]] = {}

    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:
        """Return the tool schema stored on this tool."""
        return self.tool_schema

    async def create(
        self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs
    ) -> tuple[str, ToolResponse]:
        """Register the state for a new tool instance.

        Args:
            instance_id: Identifier for the instance; a random UUID4 string is
                generated when omitted.
            ground_truth: Reference answer. When ``None`` it is read from
                ``create_kwargs["ground_truth"]``.
            **kwargs: May carry ``create_kwargs``, a mapping with
                ``data_source`` (required) and optionally ``ground_truth``
                and ``extra_info``.

        Returns:
            The instance id and an empty ``ToolResponse``.

        Raises:
            ValueError: If ``data_source`` is missing or ``None``.
        """
        if instance_id is None:
            instance_id = str(uuid4())

        # `create_kwargs` may be absent or explicitly None; treat both as empty
        # so the lookups below cannot fail on a None value.
        create_kwargs = kwargs.get("create_kwargs") or {}

        if ground_truth is None:
            ground_truth = create_kwargs.get("ground_truth")

        data_source = create_kwargs.get("data_source")
        if data_source is None:
            raise ValueError("data_source must be provided when creating the reward tool.")

        self._instance_dict[instance_id] = {
            "response": "",
            "ground_truth": ground_truth,
            "reward": 0.0,
            "data_source": data_source,
            "extra_info": create_kwargs.get("extra_info"),
        }
        return instance_id, ToolResponse()

    @rollout_trace_op
    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]:
        """Score the submitted answer and update the stored reward.

        Args:
            instance_id: Id of an instance previously registered via `create`.
            parameters: Tool-call arguments; ``parameters["answer"]`` is the
                submitted answer (defaults to ``""``; non-strings are
                converted with ``str``).

        Returns:
            A ``ToolResponse`` echoing the parsed answer and its reward, the
            tool-level reward (``0.0`` when the new reward is strictly greater
            than the stored one, otherwise ``-0.05``), and the tool metric.

        Raises:
            KeyError: If ``instance_id`` has no registered state.
        """
        answer = parameters.get("answer", "")
        if not isinstance(answer, str):
            answer = str(answer)

        state = self._instance_dict[instance_id]
        # Answers not already starting with "<answer>" are wrapped in
        # <answer>...</answer> before being stored as the response to score.
        if answer.startswith("<answer>"):
            state["response"] = answer
        else:
            state["response"] = "<answer>" + answer + "</answer>"

        reward = await self.calc_reward(instance_id)
        # penalty for non improved answer submission (stored reward starts at 0.0)
        tool_reward = 0.0 if reward > state["reward"] else -0.05
        # update the reward
        state["reward"] = reward

        # Tool Call Metric
        metric = build_tool_metric(tool_name=self.name, succeeded=True)

        return ToolResponse(text=f"Current parsed {answer=} {reward=}"), tool_reward, metric

    async def calc_reward(self, instance_id: str, **kwargs) -> float:
        """Compute the reward for the response currently stored on the instance.

        Calls ``compute_score`` with the stored data source, response, ground
        truth and extra info, and returns the ``"acc_reward"`` entry of its
        result.

        Raises:
            KeyError: If ``instance_id`` has no registered state.
        """
        state = self._instance_dict[instance_id]
        score_dict = compute_score(
            data_source=state["data_source"],
            solution_str=state["response"],
            ground_truth=state["ground_truth"],
            extra_info=state["extra_info"],
        )
        return score_dict["acc_reward"]

    async def release(self, instance_id: str, **kwargs) -> None:
        """Drop the state of ``instance_id``.

        Releasing an unknown or already-released instance is a no-op, so
        cleanup paths can call this more than once without raising.
        """
        self._instance_dict.pop(instance_id, None)