import re
from core.reasoning.style import DirectStyle
from core.reasoning.rm import RewardModel
from ..utils import In


class DirectMultStyle(DirectStyle[In, int]):
    """Direct style for multiplication: renders the operand pair and the integer outcome as text."""

    def repr_input(self, input: In, **kwargs) -> str:
        """Render the two operands unpacked from ``input`` as ``"<x> * <y>"``."""
        lhs, rhs = input
        template = "%d * %d"
        return template % (lhs, rhs)

    def repr_outcome(self, outcome: int, **kwargs) -> str:
        """Render ``outcome`` as ``"Outcome: <value>"``."""
        template = "Outcome: %d"
        return template % outcome


def direct_multiplier(input: In):
    """Return ``(thought, outcome)`` for the first two items of ``input``.

    The thought is always ``None``; the outcome is ``input[0] * input[1]``.
    """
    thought = None  # this solver produces no intermediate thought
    product = input[0] * input[1]
    return thought, product


class RE:
    """Namespace of compiled regular expressions used to parse tagged text."""

    # Any run of one or more whitespace characters.
    white_space = re.compile(r"\s+")
    # Whitespace-free "<in>X*Y</in>" span; captures the two digit runs.
    input = re.compile(r"<in>(\d+)\*(\d+)</in>")
    # Whitespace-free "<out>Outcome:N</out>" span; captures the digit run.
    outcome = re.compile(r"<out>Outcome:(\d+)</out>")

    @staticmethod
    def remove_white_space(text: str):
        """Return ``text`` with every whitespace run deleted."""
        stripped = RE.white_space.sub('', text)
        return stripped


def parse_outcome(text: str):
    """Extract the integer from the first ``<out>Outcome:N</out>`` span in ``text``.

    All whitespace is removed from ``text`` before matching. Returns the
    parsed ``int``, or ``None`` when no span matches or the captured digits
    cannot be converted by ``int``.
    """
    compact = RE.remove_white_space(text)
    found = RE.outcome.search(compact)
    if found is None:
        return None
    digits = found.group(1)
    try:
        return int(digits)
    except ValueError:
        return None
    

class DirectRM(RewardModel):
    """Reward model that scores a textual outcome against the product of two ints.

    The three reward values are configurable:

    * ``syntax_error`` is returned when ``parse_outcome`` cannot extract an
      integer from the outcome text.
    * ``value_error`` is returned when an integer is extracted but differs
      from the expected product.
    * ``correct_outcome`` is returned when the extracted integer equals the
      expected product.
    """

    def __init__(
        self,
        syntax_error: float = -1,
        value_error: float = -1,
        correct_outcome: float = 1,
    ):
        """Store the reward assigned to each of the three possible verdicts."""
        super().__init__()

        self.syntax_error: float = syntax_error
        self.value_error: float = value_error
        self.correct_outcome: float = correct_outcome

    def outcome_reward(self, _outcome: str, **references) -> float:
        """Score ``_outcome`` against the product of the reference operands.

        Args:
            _outcome: Text expected to contain an ``<out>Outcome:N</out>`` span.
            **references: Must contain ``input``, an iterable that unpacks
                into exactly two ``int`` operands.

        Returns:
            ``self.syntax_error``, ``self.correct_outcome`` or
            ``self.value_error`` (see the class docstring).

        Raises:
            ValueError: If ``input`` is missing from ``references`` (or is None).
            TypeError: If either operand is not an ``int``.
        """
        operands = references.get('input')
        if operands is None:
            raise ValueError("reward model requires \"input\" for reference.")
        x, y = operands
        # Explicit check rather than `assert`: assertions are stripped under
        # `python -O`, which would let non-int operands be scored silently.
        if not (isinstance(x, int) and isinstance(y, int)):
            raise TypeError(
                "reward model requires \"input\" to be a pair of ints, got (%s, %s)."
                % (type(x).__name__, type(y).__name__)
            )
        answer = x * y

        outcome = parse_outcome(_outcome)
        if outcome is None:
            return self.syntax_error
        if outcome == answer:
            return self.correct_outcome
        return self.value_error
