from loguru import logger
import os
from typing import Any, Dict, List, Optional, Tuple
from uuid import uuid4

from distflow.utils.reward_score import gsm8k

from .base import BaseInteraction



class Gsm8kInteraction(BaseInteraction):
    """A demo interaction for calculating the reward of gsm8k.

    Per-trajectory state is kept in ``self._instance_dict``, keyed by
    instance id. Each entry holds the latest ``"response"``, the
    ``"ground_truth"`` supplied at start, and a ``"reward"`` slot that is
    initialised to ``0.0``.

    - `start_interaction`: start an interaction instance for a trajectory.
    - `generate_response`: generate the response of the user.
    - `calculate_score`: calculate the score of the interaction.
    - `finalize_interaction`: finalize the interaction instance.
    """

    def __init__(self, config: dict):
        """Initialise the base interaction and an empty instance registry."""
        super().__init__(config)
        # instance_id -> {"response": ..., "ground_truth": ..., "reward": ...}
        self._instance_dict: Dict[str, Dict[str, Any]] = {}

    async def start_interaction(
        self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs
    ) -> str:
        """Register a new interaction instance and return its id.

        Args:
            instance_id: Identifier to register under. A random UUID4 string
                is generated when this is ``None``.
            ground_truth: Reference answer stored for later scoring.
            **kwargs: Accepted and ignored.

        Returns:
            The id under which the instance was registered. Registering an
            id that already exists replaces its stored state.
        """
        if instance_id is None:
            instance_id = str(uuid4())
        self._instance_dict[instance_id] = {
            "response": "",
            "ground_truth": ground_truth,
            "reward": 0.0,
        }
        return instance_id

    async def generate_response(
        self, instance_id: str, messages: List[Dict[str, Any]], **kwargs
    ) -> Tuple[bool, str, float, dict]:
        """Score the latest ``"user"`` message and produce feedback text.

        Args:
            instance_id: Id returned by `start_interaction`.
            messages: Conversation so far; each item is read via
                ``item.get("role")`` and ``item.get("content")``.
            **kwargs: Accepted and ignored.

        Returns:
            A tuple ``(should_terminate_sequence, response, reward, {})``.
            The sequence terminates only when the reward equals ``1.0``.

        Raises:
            KeyError: If ``instance_id`` has not been started (or was
                already finalized).
        """
        content = ""
        # Walk backwards so the most recent "user" message is the one scored.
        for item in reversed(messages):
            if item.get("role") == "user":
                content = item.get("content")
                break

        # Ensure the stored response carries the "#### " answer prefix.
        # NOTE(review): content is assumed to be a str or None here; other
        # content types (e.g. a list of parts) are not handled - confirm
        # against callers.
        if content and content.startswith("#### "):
            answer = content
        else:
            answer = "#### " + (content or "")
        self._instance_dict[instance_id]["response"] = answer

        reward = await self.calculate_score(instance_id)
        if reward == 1.0:
            response = "Your response is correct!"
            should_terminate_sequence = True
        else:
            response = "Your response is incorrect! You need to reflect on your answer and try again."
            should_terminate_sequence = False

        return should_terminate_sequence, response, reward, {}

    async def calculate_score(self, instance_id: str, **kwargs) -> float:
        """Score the stored response against the stored ground truth.

        Delegates to ``gsm8k.compute_score`` with ``method="flexible"``,
        ``format_score=0.0`` and ``score=1.0``.
        NOTE(review): presumably this yields ``1.0`` on a match and ``0.0``
        otherwise - confirm against ``gsm8k.compute_score``.

        Raises:
            KeyError: If ``instance_id`` is not a registered instance.
        """
        instance = self._instance_dict[instance_id]
        return gsm8k.compute_score(
            instance["response"],
            instance["ground_truth"],
            method="flexible",
            format_score=0.0,
            score=1.0,
        )

    async def finalize_interaction(self, instance_id: str, **kwargs) -> None:
        """Drop the stored state for ``instance_id``.

        Idempotent: finalizing an unknown or already-finalized instance is a
        no-op, so cleanup paths may call this more than once safely.
        """
        self._instance_dict.pop(instance_id, None)
