"""
This module defines data structures and base classes for reward calculations
to evaluate model responses for various problem types, including math and coding.
"""

from dataclasses import dataclass, field
from enum import Enum

@dataclass
class RewardConfig:
    """Settings shared by reward functions.

    Attributes:
        use_math_orm (bool): Use an LLM as an outcome reward model (ORM) to
            judge the correctness of math responses.
        use_code_exe_reward (bool): Feature flag for code rewards; presumably
            enables execution-based scoring — confirm against subclasses.
        correct_reward (float): Reward value for a correct response.
        incorrect_reward (float): Reward value for an incorrect response.
        format_error_reward (float): Reward value for a response with a
            format error.
        unk_error_reward (float): Reward value for an unknown error.
    """

    # Evaluation feature flags (both off by default).
    use_math_orm: bool = False
    use_code_exe_reward: bool = False

    # Reward magnitudes; every failure mode defaults to the same penalty.
    correct_reward: float = 1.0
    incorrect_reward: float = -1.0
    format_error_reward: float = -1.0
    unk_error_reward: float = -1.0


class RewardType(Enum):
    """Problem category that selects how a response is rewarded.

    Each member's value is the uppercase string form of its name.

    Attributes:
        MATH (str): Math problems.
        CODE (str): Coding problems.
        UNK (str): Unknown / unclassified problems.
    """

    MATH = "MATH"
    CODE = "CODE"
    UNK = "UNK"


@dataclass
class RewardInput:
    """Everything a reward function receives to score one model response.

    Attributes:
        problem (str): Problem text / prompt given to the model.
        model_response (str): The model's output that is to be scored.
        problem_type (RewardType): Category of the problem (e.g. math or
            code); ``RewardType.UNK`` when not supplied.
        ground_truth (dict): Extra reference data needed for scoring. For
            math problems this may hold the expected answer; for coding
            problems it may hold unit tests for validating the solution.
    """

    problem: str
    model_response: str
    problem_type: RewardType = RewardType.UNK
    # default_factory gives each instance its own dict, avoiding a shared
    # mutable default across instances.
    ground_truth: dict = field(default_factory=dict)


@dataclass
class RewardOutput:
    """Result a reward function returns for a single response.

    Attributes:
        reward (float): Scalar reward assigned to the response.
        is_correct (bool): Whether the response was judged correct.
    """

    reward: float
    is_correct: bool


class RewardFn:
    """Base interface for reward-calculation strategies.

    Subclasses implement ``__call__`` to turn a ``RewardInput`` into a
    ``RewardOutput``. This class does not inherit from ``abc.ABC``, so it can
    be instantiated directly; the missing implementation only surfaces as a
    ``NotImplementedError`` when the instance is called.
    """

    def __init__(self, config: RewardConfig):
        """Keep the reward configuration for use by subclasses.

        Args:
            config: Reward settings, exposed as ``self.config``.
        """
        self.config = config

    def __call__(self, input: RewardInput) -> RewardOutput:
        """Score ``input``; subclasses must override this.

        Args:
            input: The problem, the model response and the reference data.
                (The name shadows the builtin but is kept for keyword callers.)

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement this method.")