import logging
import os
from typing import Any, Optional
from uuid import uuid4
from verl.utils.reward_score import gsm8k
from .base import BaseInteraction
# Module-level logger. Its verbosity comes from the VERL_LOGGING_LEVEL
# environment variable and falls back to "WARN" when the variable is unset.
logger = logging.getLogger(__name__)
logger.setLevel(os.environ.get("VERL_LOGGING_LEVEL", "WARN"))
class Gsm8kInteraction(BaseInteraction):
    """Interaction that scores assistant answers against a GSM8K ground truth.

    State is tracked per interaction instance in ``self._instance_dict``, keyed by
    instance id. Each entry stores the latest assistant ``response``, the
    ``ground_truth`` supplied at start, and a ``reward`` slot initialised to 0.0.
    """

    def __init__(self, config: dict):
        """Initialise the base interaction and an empty per-instance state table."""
        super().__init__(config)
        self._instance_dict = {}

    async def start_interaction(
        self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs
    ) -> str:
        """Register a new interaction instance and return its id.

        Args:
            instance_id: Identifier to use; a random UUID4 string is generated when None.
            ground_truth: Reference answer that later responses are scored against.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The instance id under which the state was stored.
        """
        if instance_id is None:
            instance_id = str(uuid4())
        self._instance_dict[instance_id] = {
            "response": "",
            "ground_truth": ground_truth,
            "reward": 0.0,
        }
        return instance_id

    async def generate_response(
        self, instance_id: str, messages: list[dict[str, Any]], **kwargs
    ) -> tuple[bool, str, float, dict]:
        """Score the most recent assistant message and produce feedback text.

        Args:
            instance_id: Id previously returned by ``start_interaction``.
            messages: Conversation so far, as dicts with ``role`` / ``content`` keys.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            A tuple ``(should_terminate_sequence, response, reward, metadata)``.
            The sequence terminates only when the reward equals 1.0; ``metadata``
            is always an empty dict.

        Raises:
            KeyError: If ``instance_id`` has not been started.
        """
        # Walk the conversation backwards so the latest assistant turn is graded.
        content = ""
        for message in reversed(messages):
            if message.get("role") == "assistant":
                # ``content`` may be absent or None; grade it as empty text rather
                # than handing a non-string to the scorer.
                content = message.get("content") or ""
                break

        self._instance_dict[instance_id]["response"] = content
        reward = await self.calculate_score(instance_id)

        if reward == 1.0:
            response = "Your response is correct!"
            should_terminate_sequence = True
        else:
            response = "Your response is incorrect! You need to reflect on your answer and try again."
            should_terminate_sequence = False
        return should_terminate_sequence, response, reward, {}

    async def calculate_score(self, instance_id: str, **kwargs) -> float:
        """Return the score for the stored response of ``instance_id``.

        Delegates to ``gsm8k.compute_score`` with ``method="strict"``,
        ``format_score=0.0`` and ``score=1.0``.

        Raises:
            KeyError: If ``instance_id`` has not been started.
        """
        state = self._instance_dict[instance_id]
        return gsm8k.compute_score(
            state["response"],
            state["ground_truth"],
            method="strict",
            format_score=0.0,
            score=1.0,
        )

    async def finalize_interaction(self, instance_id: str, **kwargs) -> None:
        """Drop the stored state for ``instance_id``.

        Unknown or already-finalized ids are ignored so that cleanup can be
        called more than once without raising.
        """
        self._instance_dict.pop(instance_id, None)