

from dataclasses import dataclass
from enum import Enum

@dataclass
class RewardConfig:
    """Configuration values handed to a reward function.

    Fields are grouped by concern below. Their declaration order is also the
    positional order of the generated ``__init__``, so do not reorder them.
    """

    # --- math problems ---
    math_reward_weight: float = 1.0
    # NOTE(review): flag name suggests an outcome reward model for math;
    # how it is used is decided by the reward function, not here — confirm.
    use_math_orm: bool = False

    # --- code problems ---
    code_reward_weight: float = 1.0

    # --- chain-of-thought (weight is 0.0 by default) ---
    cot_reward_weight: float = 0.0

    # --- reward values, presumably chosen per grading outcome ---
    correct_reward: float = 1.0
    incorrect_reward: float = 0.0
    format_error_reward: float = 0.0
    unk_error_reward: float = 0.0

    # --- tool-call bonus; presumably added when a tool call is made (confirm) ---
    toolcall_bonus: float = 0.5


class RewardType(Enum):
    """Category of problem a reward is computed for.

    Every member's value equals its own name, so
    ``RewardType("MATH") is RewardType.MATH``.
    """

    MATH = "MATH"
    CODE = "CODE"
    UNK = "UNK"  # unknown / unclassified category


@dataclass
class RewardInput:
    """Bundle of data describing one model response to be scored.

    Attributes:
        problem: Text of the problem posed to the model.
        data_source: Presumably names the dataset the problem came from.
        model_response: Text produced by the model.
        metadata: Extra information; its keys are defined by the caller.
        problem_type: Problem category; ``RewardType.UNK`` unless given.
    """

    problem: str
    data_source: str
    model_response: str
    metadata: dict
    problem_type: RewardType = RewardType.UNK



@dataclass
class LiveCodebenchInput:
    """Data for scoring generated code, presumably for LiveCodeBench problems.

    Attributes:
        question: Problem statement text.
        generation_code: Code produced by the model.
        problem: Problem record as a dict; structure defined by the caller.
        difficult: Difficulty label, ``'easy'`` unless given. The spelling
            (not ``difficulty``) is part of the public interface.
        problem_type: Problem category; ``RewardType.CODE`` unless given.
    """

    question: str
    generation_code: str
    problem: dict
    difficult: str = 'easy'
    problem_type: RewardType = RewardType.CODE


@dataclass
class RewardOutput:
    """Score produced for a single response.

    Attributes:
        reward: Numeric reward assigned to the response.
        is_correct: Whether the response was judged correct.
    """

    reward: float
    is_correct: bool


class RewardFn:
    """Base class for callables that turn a ``RewardInput`` into a ``RewardOutput``.

    Subclasses override ``__call__``; this base implementation only raises.
    """

    def __init__(self, config: RewardConfig):
        # Stored so subclasses can read their settings via ``self.config``.
        self.config = config

    def __call__(self, input: RewardInput) -> RewardOutput:
        """Score ``input``. Must be overridden by subclasses.

        ``input`` shadows the builtin but is kept because callers may pass
        it by keyword.

        Raises:
            NotImplementedError: always, in this base class.
        """
        message = "Subclasses must implement this method."
        raise NotImplementedError(message)

