import numpy as np
import numpy.random as rd
import re
from core.reasoning.rm import RewardModel
from core.reasoning.formulation import StepReasoner
from core.reasoning.style import StepwiseStyle
from ..utils import In


class MultistepMultiplier(StepReasoner[In, list[str], int]):
    """Long multiplication carried out as a sequence of rewriting steps.

    The working state is a list of ``(b, e)`` pairs, each rendered as
    ``"<b>e<e>"`` and standing for ``b * 10**e``.  ``begin`` expands the two
    operands into digit-by-digit partial products, every ``step`` call either
    splits multi-digit coefficients into single digits or merges the terms
    that share an exponent, and ``end`` assembles the remaining digits into
    the final integer.

    With ``noise > 0`` individual elementary operations are randomly
    corrupted; with ``completeness < 1`` intermediate states are randomly
    left out of the recorded trace.
    """

    def __init__(self, noise: float = 0, completeness: float = 1):
        """Configure error injection and trace completeness.

        Args:
            noise: Probability that a single elementary operation is carried
                out wrongly.  ``0`` disables error injection entirely (no
                random numbers are drawn for it).
            completeness: Probability that an intermediate state is recorded
                as a step.  Values ``>= 1`` record every state.
        """
        self._noise = noise
        self._completeness = completeness

        if noise:
            self._correct = lambda: rd.rand() >= noise
        else:
            self._correct = lambda: True

        self._items: list[tuple[int, int]] = []
        self._steps: list[str] = []

    def __mkvstate__(self) -> list[tuple[int, int]]:
        """Return the current list of ``(b, e)`` terms.

        NOTE(review): presumably a state hook consumed by the ``StepReasoner``
        framework -- confirm against the base class.
        """
        return self._items

    def begin(self, input: tuple[int, int]) -> None:
        """Start a new multiplication ``x * y`` and reset the recorded steps.

        Every digit of ``x`` is paired with every digit of ``y`` (least
        significant digit first), yielding one ``(xi * yi, xe + ye)`` term per
        pair.  Under noise, each digit and each position may independently be
        replaced by a random one.
        """
        x, y = input
        correct = self._correct
        digits_x = [int(xi) for xi in reversed(str(x))]
        digits_y = [int(yi) for yi in reversed(str(y))]
        len_x = len(digits_x)
        len_y = len(digits_y)
        # The four draws per pair happen in this exact order; keep it so that
        # seeded runs stay reproducible.
        items = [
            (
                xi if correct() else rd.randint(10),
                yi if correct() else rd.randint(10),
                xe if correct() else rd.randint(len_x),
                ye if correct() else rd.randint(len_y),
            )
            for xe, xi in enumerate(digits_x)
            for ye, yi in enumerate(digits_y)
        ]

        # Collapse each (digit, digit, pos, pos) tuple into a (b, e) term.
        items = [(xi * yi, xe + ye) for (xi, yi, xe, ye) in items]

        self._steps.clear()
        self._items = items

    def step(self) -> bool | None:
        """Record the current state and apply one rewriting pass.

        Splitting is tried first; merging is only attempted when splitting
        left the state unchanged.

        Returns:
            ``True`` once neither splitting nor merging changes the state,
            ``None`` while there is still work to do.
        """
        self._add_step(self._items)

        self._items, changed = self._split_items(self._items)
        if changed:
            return
        self._items, changed = self._merge_items(self._items)
        if changed:
            return

        return True

    def end(self) -> tuple[list[str], int]:
        """Return a copy of the recorded steps and the resulting integer.

        Each remaining term writes ``str(b)`` at decimal position ``e``;
        positions without a term are ``0``.  An empty state yields ``0``.
        """
        if self._items:
            max_e = max(e for _, e in self._items)
            digits = ['0'] * (max_e + 1)
            for b, e in self._items:
                digits[e] = str(b)
            digits.reverse()
            # int() accepts leading zeros.  Stripping them beforehand would
            # turn an all-zero string into '' and make int() raise.
            out = int(''.join(digits))
        else:
            out = 0

        return self._steps.copy(), out

    @classmethod
    def _repr_items(cls, items: list[tuple[int, int]]):
        """Format terms as ``"= b1e1 + b2e2 + ..."`` (``"= 0"`` when empty)."""
        if items:
            return '= ' + " + ".join("%de%d" % item for item in items)
        else:
            return "= 0"

    def _add_step(self, items: list[tuple[int, int]]):
        """Append the rendered state to the trace, subject to completeness."""
        completeness = self._completeness
        # The short-circuit avoids drawing a random number when every step is
        # recorded anyway.
        if completeness >= 1 or rd.rand() <= completeness:
            self._steps.append(self._repr_items(items))

    def _split_item(self, b: int, e: int):
        """Split ``b * 10**e`` into single-digit terms, dropping zero digits.

        Under noise, a digit may be replaced by a random one and the exponent
        may fail to advance.
        """
        out: list[tuple[int, int]] = []
        while b:
            x = b % 10 if self._correct() else rd.randint(10)
            if x:
                out.append((x, e))
            e = e + 1 if self._correct() else e
            b = b // 10
        return out

    def _split_items(self, items: list[tuple[int, int]]):
        """Split every term; report whether any term was actually altered.

        Returns:
            ``(new_items, changed)``.
        """
        out: list[tuple[int, int]] = []
        changed = False
        for b, e in items:
            splits = self._split_item(b, e)
            # A term counts as unchanged only if it came back as itself.
            changed = changed or len(splits) != 1 or splits[0] != (b, e)
            out.extend(splits)
        return out, changed

    def _merge_items(self, items: list[tuple[int, int]]):
        """Sum the terms that share an exponent, dropping zero sums.

        Under noise, a term may be left out of its sum and the sum itself may
        be replaced by a random value in ``[max(0, b - 10), b + 10]``.

        Returns:
            ``(new_items, changed)`` where ``changed`` is ``True`` if at least
            two terms were combined.
        """
        out: list[tuple[int, int]] = []
        changed = False
        if not items:
            # Nothing left to merge (e.g. a zero operand was split away);
            # max() over an empty sequence would raise ValueError.
            return out, changed
        maxe = max(e for _, e in items)
        for e in range(maxe + 1):
            bs = [b for (b, e_) in items if (e_ == e and self._correct())]
            b = sum(bs)
            changed = changed or (len(bs) > 1)
            if not self._correct():
                b = rd.randint(max(0, b-10), b+11)
            if b:
                out.append((b, e))
        return out, changed


class MultistepMultStyle(StepwiseStyle[tuple[int, int], str, int]):
    """Text rendering of a multiplication problem and of its final result."""

    def repr_input(self, input: tuple[int, int], **kwargs) -> str:
        """Render the operand pair as ``"<x> * <y>"``; ``kwargs`` are unused."""
        lhs, rhs = input
        operands = (lhs, rhs)
        return "%d * %d" % operands

    def repr_outcome(self, outcome: int, **kwargs) -> str:
        """Render the result as ``"Outcome: <n>"``; ``kwargs`` are unused."""
        template = "Outcome: %d"
        return template % outcome


class RE:
    """Namespace for the compiled regular expressions of this module."""

    # Captures the run of digits that follows an "Outcome" label; whitespace
    # on either side of the colon is optional.
    outcome = re.compile(
        r"Outcome\s*:\s*(\d+)"
    )


def _outcome_reward(
    output: str,
    input: tuple[int, int],
    syntax_error: float = 0.,
    result_error: float = 0.,
    correct: float = 1.
) -> float:
    """Score a generated answer against the true product of ``input``.

    The first ``Outcome: <digits>`` occurrence in ``output`` is compared, as a
    string, with ``str(x * y)``.

    Args:
        output: Generated text to be scored.
        input: The two integer operands ``(x, y)``.
        syntax_error: Reward returned when no outcome can be found.
        result_error: Reward returned when the outcome differs from the product.
        correct: Reward returned when the outcome equals the product.

    Returns:
        One of ``syntax_error``, ``result_error`` or ``correct``.
    """
    lhs, rhs = input
    assert isinstance(lhs, int) and isinstance(rhs, int)
    expected = str(lhs * rhs)

    found = RE.outcome.search(output)
    if found is None:
        return syntax_error

    # Textual comparison: the captured digits must match exactly.
    return correct if found.group(1) == expected else result_error


class MultistepRM(RewardModel):
    """Reward model that scores the final outcome of a multiplication."""

    def __init__(
        self,
        syntax_error: float = 0.,
        result_error: float = 0.,
        correct: float = 1.,
    ):
        """Store the reward value for each of the three possible verdicts.

        Args:
            syntax_error: Reward when no outcome can be parsed.
            result_error: Reward when the parsed outcome is wrong.
            correct: Reward when the parsed outcome is right.
        """
        super().__init__()

        self.syntax_error: float = syntax_error
        self.result_error: float = result_error
        self.correct: float = correct

    def outcome_reward(self, _outcome: str, **references) -> float:
        """Score ``_outcome`` against the operands given as ``input``.

        Raises:
            ValueError: If ``references`` has no ``input`` entry or it is
                ``None``.
        """
        operands = references.get('input')
        if operands is not None:
            return _outcome_reward(
                _outcome,
                operands,
                syntax_error=self.syntax_error,
                result_error=self.result_error,
                correct=self.correct,
            )

        raise ValueError("reward model requires \"input\" for reference.")
