import numpy as np
import numpy.random as rd
import random
import re
import dataclasses as dc

from enum import StrEnum
from core.reasoning.formulation import MDPStep, AgentReasoner
from core.reasoning.style import MDPStyle, Segment
from core.reasoning.rm import RewardModel
from typing import Literal, NamedTuple, Any, cast, Callable
from ..utils import In


# Decimal digits of a number, least-significant digit first
# (the order produced by `n2d` and consumed by `d2n`).
type Digits = list[int]


def n2d(number: int) -> Digits:
    """Split ``number`` into its decimal digits, least-significant first.

    Every character of ``str(number)`` is converted with ``int``, so a
    negative number raises ``ValueError`` on the sign character.
    """
    text = str(number)
    return [int(ch) for ch in reversed(text)]


def d2n(digits: Digits) -> int:
    """Rebuild an integer from digits stored least-significant first.

    Each entry must be a single decimal digit (checked with ``assert``).
    """
    for digit in digits:
        assert 0 <= digit < 10
    most_significant_first = reversed(digits)
    return int(''.join(str(digit) for digit in most_significant_first))


def repr_number(x: int):
    """Render ``x`` with one space between consecutive characters.

    For example ``123`` becomes ``"1 2 3"``.
    """
    characters = list(str(x))
    return ' '.join(characters)


def noised(noise: float):
    """Return ``True`` with probability ``noise``.

    A non-positive ``noise`` yields ``False`` immediately, without drawing a
    random number.
    """
    if noise <= 0:
        return False
    draw = random.random()
    return draw < noise

    
def _count_nonzero_digit(n: int):
    """Count the characters of ``str(n)`` that are not ``'0'``."""
    return sum(1 for ch in str(n) if ch != '0')


def _positions(d: int, digits: list[int], noise: float, count: int | None = None) -> list[int]:
    """List the indices of ``digits`` that hold the digit ``d``.

    Each matching index is independently dropped when ``noised(noise)`` fires
    (the noise draw happens only for matching digits).  When ``count`` is
    given, the result is cut with ``[:count]``.
    """
    matches: list[int] = []
    for index, digit in enumerate(digits):
        if digit != d:
            continue
        if noised(noise):
            continue
        matches.append(index)
    if count is not None:
        return matches[:count]
    return matches


@dc.dataclass(repr=False)
class _MultState:

    x: int
    y: int
    z: int

    def __repr__(self):
        return f"{repr_number(self.x)} * {repr_number(self.y)} + {repr_number(self.z)}"
        
    def reduce_x(self, d: int, noise: float, count: int | None) -> str:
        assert 0 <= d < 10
        dx = n2d(self.x)
        dy = n2d(self.y)
        lines = []
        space = " "

        lines.append(space.join(str(d)) + " * right:")
        prog: int = 0
        for yj in dy:
            z = d * yj
            if prog == 0:
                lines.append(f"- {repr_number(z)}")
            else:
                assert prog < 10
                lines.append(f"- {repr_number(z)} + {prog} = {repr_number(z + prog)}")
            prog = (z + prog) // 10
        if prog != 0:
            assert prog < 10
            lines.append(f"- {prog}")
        
        p = d * self.y
        for i in _positions(d, dx, noise, count):
            a = p * (10**i)
            lines.extend(self.update_z(a))
            dx[i] = 0
        self.x = d2n(dx)
        return "\n".join(lines)

    def reduce_y(self, d: int, noise: float, count: int | None) -> str:
        assert 0 <= d < 10
        dx = n2d(self.x)
        dy = n2d(self.y)
        lines = []
        space = " "
        lines.append("left * " + space.join(str(d)) + ":")
        prog: int = 0
        for xi in dx:
            z = xi * d
            if prog == 0:
                lines.append(f"- {repr_number(z)}")
            else:
                assert prog < 10
                lines.append(f"- {repr_number(z)} + {prog} = {repr_number(z + prog)}")
            prog = (z + prog) // 10
        if prog != 0:
            assert prog < 10
            lines.append(f"- {prog}")
        
        p = self.x * d
        for j in _positions(d, dy, noise, count):
            a = p * (10**j)
            lines.extend(self.update_z(a))
            dy[j] = 0
        self.y = d2n(dy)
        return "\n".join(lines)
    
    def update_z(self, a: int) -> list[str]:
        lines: list[str] = []
        lines.append(f"cumulate {repr_number(a)}:")
        dz = n2d(self.z)
        da = n2d(a)
        maxd = max(len(dz), len(da))
        dout: list[int] = []
        prog = 0
        for i in range(maxd):
            items: list[int] = []
            if i < len(dz):
                items.append(dz[i])
            if i < len(da):
                items.append(da[i])
            if prog:
                items.append(prog)
            s = sum(items)
            if len(items) <= 1:
                line = repr_number(s)
            else:
                expr = " + ".join(map(str, items))
                line = expr + " = " + repr_number(s)
            lines.append("- " + line)
            dout.append(s % 10)
            prog = s // 10
        while prog != 0:
            lines.append("- " + repr_number(prog))
            dout.append(prog % 10)
            prog = prog // 10
        self.z = d2n(dout)
        lines.append(f"get {repr_number(self.z)}.")
        return lines
    
    def terminated(self):
        return self.x == 0 or self.y == 0
    
    def get_outcome(self):
        return self.x * self.y + self.z


class _MDPAction(NamedTuple):
    """An action selected by ``MDPMultiplier.actor`` and applied by ``transit``."""

    # Which transition to apply to the current multiplication state.
    fn: Literal['reduce_x', 'reduce_y', 'terminate']
    # Argument of the transition: `actor` passes the digit to reduce for the
    # two reduce actions and leaves the default None for 'terminate'.
    arg: Any = None


class MDPMultiplier(AgentReasoner[In, _MDPAction, list[MDPStep[str, str]], int]):
    """Agent that multiplies two integers by repeatedly reducing digits.

    The state is a ``_MultState`` (``x * y + z``).  Each step picks a non-zero
    digit of ``x`` or ``y`` and moves its contribution into ``z``; the episode
    ends once ``x`` or ``y`` is zero.

    Args:
        noise: When non-zero, ``actor`` picks a uniformly random (possibly
            invalid) action whenever ``numpy.random.rand() < noise``; the same
            value is forwarded to the state reductions as their noise level.
        max_reduction: Forwarded as ``count`` to the state reductions, capping
            how many occurrences of the digit are reduced per step.
    """

    def __init__(self, noise: float = 0, max_reduction: int | None = None):
        self._noise = noise
        self._max_reduction = max_reduction

        if noise:
            self._correct = lambda: rd.rand() >= noise
        else:
            self._correct = lambda: True
        self._thought: list[MDPStep[str, str]] = []
        self._state: _MultState

    @staticmethod
    def _digits_of(number: int) -> Digits:
        """Return the digits of ``number``, least-significant first."""
        return n2d(number)

    def begin(self, input: tuple[int, int]) -> None:
        """Start a new episode for ``x * y`` with an empty thought list."""
        x, y = input
        self._state = _MultState(x, y, 0)
        self._thought.clear()

    def actor(self) -> _MDPAction:
        """Choose the next action for the current state.

        On a "correct" draw the action is ``terminate`` if the state is
        terminated, otherwise a uniformly random choice among one reduce
        action per non-zero digit occurrence in ``x`` and ``y``.  On an
        "incorrect" draw any of ``terminate`` / ``reduce_x(0..9)`` /
        ``reduce_y(0..9)`` is chosen uniformly.
        """
        state = self._state
        dx = n2d(state.x)
        dy = n2d(state.y)

        if not self._correct():
            choices = [_MDPAction("terminate")]
            choices.extend(_MDPAction("reduce_x", d) for d in range(10))
            choices.extend(_MDPAction("reduce_y", d) for d in range(10))
            return random.choice(choices)

        if state.terminated():
            return _MDPAction("terminate")
        # A digit that occurs several times is listed several times, so it is
        # proportionally more likely to be picked.
        choices = [_MDPAction("reduce_x", d) for d in dx if d != 0]
        choices.extend(_MDPAction("reduce_y", d) for d in dy if d != 0)
        if choices:
            return random.choice(choices)
        return _MDPAction("terminate")

    def transit(self, action: _MDPAction) -> bool | None:
        """Apply ``action`` to the state and record the step.

        A step ``{'state': ..., 'action': ...}`` is appended to the thought
        list; for ``terminate`` its action stays ``None``.

        Returns:
            ``True`` when the action is ``terminate``, otherwise ``None``.

        Raises:
            ValueError: If ``action.fn`` is not a known action name.
        """
        # Validate first so an unknown action neither records a half-filled
        # step nor fails later with an unbound local.
        if action.fn not in ('reduce_x', 'reduce_y', 'terminate'):
            raise ValueError(f"unknown action: {action.fn!r}")

        self._thought.append({'state': repr(self._state), 'action': None})
        if action.fn == 'terminate':
            return True

        assert isinstance(action.arg, int)
        if action.fn == 'reduce_x':
            text = self._state.reduce_x(action.arg, self._noise, self._max_reduction)
        else:
            text = self._state.reduce_y(action.arg, self._noise, self._max_reduction)

        self._thought[-1]['action'] = text
        return None

    def end(self) -> tuple[list[MDPStep[str, str]], int]:
        """Return a copy of the recorded steps and the final ``x * y + z``."""
        outcome = self._state.get_outcome()
        return self._thought.copy(), outcome


class MDPMultStyle(MDPStyle[In, str, str, int]):
    """Text style for multiplication traces: tagged state and action segments."""

    # Delimiters wrapped around a serialized state / action.
    state = Segment('<state>', '</state>', space='\n')
    action = Segment('<action>', '</action>', space='\n')

    def repr_input(self, input: In, **kwargs) -> str:
        """Format the two operands of ``input`` as ``"x * y"``."""
        left, right = input
        return "%d * %d" % (left, right)

    def repr_outcome(self, outcome: int, **kwargs) -> str:
        """Format the final product as ``"Outcome: <value>"``."""
        text = "Outcome: %d" % outcome
        return text


class RE:
    """Compiled patterns for parsing whitespace-stripped trace text.

    The tag patterns contain no whitespace themselves, so callers strip
    whitespace (``remove_white_space``) before searching.
    """

    white_space = re.compile(r"\s+")
    state = re.compile(r"<state>(.*)</state>")
    action = re.compile(r"<action>(.*)</action>")
    input = re.compile(r"<in>(\d+)\*(\d+)</in>")
    outcome = re.compile(r"<out>Outcome:(\d+)</out>")

    @staticmethod
    def remove_white_space(text: str):
        """Return ``text`` with every run of whitespace deleted."""
        return RE.white_space.sub('', text)


def parse_input(text: str):
    """Extract the operands from an ``<in>x*y</in>`` segment.

    Whitespace in ``text`` is ignored.  Returns ``(x, y)`` as ints, or
    ``None`` when no such segment is found or the numbers cannot be converted.
    """
    compact = RE.remove_white_space(text)
    found = RE.input.search(compact)
    if found is None:
        return None
    left, right = found.groups()
    try:
        return int(left), int(right)
    except ValueError:
        return None


def parse_state(text: str):
    """Parse a ``<state>x * y + z</state>`` segment into a ``_MultState``.

    Whitespace in ``text`` is ignored and the ``+ z`` part is optional
    (``z`` defaults to 0).

    Returns:
        The parsed ``_MultState``, or ``None`` when there is no state segment
        or its content is malformed: wrong number of ``*`` factors, more than
        one ``+`` term, or non-integer parts.
    """
    # Remove all white space before matching; the pattern contains none.
    m = re.search(RE.state, RE.remove_white_space(text))
    if m is None:
        return None

    terms = m.group(1).split('+')
    # "x*y+z+w" is not a valid state; silently dropping "+w" would let a
    # malformed state be treated (and rewarded) as a well-formed one.
    if len(terms) > 2:
        return None
    try:
        x, y = terms[0].split('*')
        x = int(x)
        y = int(y)
        z = int(terms[1]) if len(terms) == 2 else 0
    except ValueError:
        return None
    return _MultState(x, y, z)

    
def parse_outcome(text: str):
    """Extract the integer from an ``<out>Outcome: N</out>`` segment.

    Whitespace in ``text`` is ignored.  Returns ``None`` when the segment is
    absent or the number cannot be converted.
    """
    found = RE.outcome.search(RE.remove_white_space(text))
    if found is None:
        return None
    try:
        return int(found.group(1))
    except ValueError:
        return None


class _ParseContext(NamedTuple):
    """Result of parsing one (prompt, output) pair of a trace."""

    # Whether prompt and output evaluate to the same number.
    correct: bool
    # What was recognised in the prompt / output (None if nothing parsed).
    prompt_type: Literal["input", "state", None]
    output_type: Literal["state", "outcome", None]
    # The parsed objects themselves (None if nothing parsed).
    prompt_object: Any
    output_object: Any

    @classmethod
    def _extract_from_dict(cls, d: dict):
        """Build a context from the entries of ``d`` named after the fields.

        Extra keys in ``d`` are ignored; a missing field raises ``KeyError``.
        """
        picked = [d[name] for name in cls._fields]
        return _ParseContext(*picked)


def parse_context(prompt: str, output: str) -> _ParseContext:
    """Parse a prompt/output pair and check that they denote the same value.

    The prompt is read as a state if possible, otherwise as an input; the
    output is read as a state if possible, otherwise as an outcome.  A side
    that cannot be parsed is given the value NaN, which never compares equal,
    so ``correct`` is ``False`` whenever either side fails to parse.
    """
    prompt_state = parse_state(prompt)
    if prompt_state is not None:
        left = prompt_state.get_outcome()
        prompt_type, prompt_object = "state", prompt_state
    else:
        operands = parse_input(prompt)
        if operands is not None:
            x, y = operands
            left = x * y
            prompt_type, prompt_object = "input", operands
        else:
            left = np.nan
            prompt_type, prompt_object = None, None

    next_state = parse_state(output)
    if next_state is not None:
        right = next_state.get_outcome()
        output_type, output_object = "state", next_state
    else:
        final = parse_outcome(output)
        if final is not None:
            right = final
            output_type, output_object = "outcome", final
        else:
            right = np.nan
            output_type, output_object = None, None

    return _ParseContext(
        correct=(left == right),
        prompt_type=prompt_type,
        output_type=output_type,
        prompt_object=prompt_object,
        output_object=output_object,
    )


def _to_reflect_segment(base: Segment):
    """Return the segment to use for reflection output.

    The segment is returned unchanged.  An earlier variant that rewrapped the
    begin/end tokens as ``<reflect:...>`` tags was deprecated; its code sat
    unreachable after the ``return`` and has been removed.
    """
    return base


class MDPMultRM(RewardModel):
    """Rule-based reward model for multiplication traces.

    Args:
        syntax_error: Reward when the prompt or the output cannot be parsed.
        value_error: Reward when both parse but denote different values.
        correct_state: Scale of the reward for a correct intermediate state.
        correct_outcome: Reward for a correct final outcome.
    """

    def __init__(
        self,
        syntax_error: float = -1,
        value_error: float = -1,
        correct_state: float = 0,
        correct_outcome: float = 1,
    ):
        super().__init__()

        self.syntax_error: float = syntax_error
        self.value_error: float = value_error
        self.correct_state: float = correct_state
        self.correct_outcome: float = correct_outcome

    def _parse_process(self, prompt: str, output: str) -> dict[str, Any]:
        """Parse one step into the dict form of ``_ParseContext``."""
        return parse_context(prompt, output)._asdict()

    def process_reward(self, prompt: str, output: str, **references) -> float:
        """Reward a single step described by ``references``.

        ``references`` must contain the ``_ParseContext`` fields (the shape
        returned by ``_parse_process``).  A correct state step earns
        ``correct_state * (d - dnext)``, where ``d`` and ``dnext`` are the
        products of the non-zero digit counts of the two factors before and
        after the step.

        Raises:
            AssertionError: If a step marked correct has an unexpected prompt
                or output type.
        """
        ctx = _ParseContext._extract_from_dict(references)
        if not ctx.correct:
            if ctx.output_type is None or ctx.prompt_type is None:
                return self.syntax_error
            return self.value_error

        if ctx.output_type == 'outcome':
            return self.correct_outcome
        if ctx.output_type != 'state':
            # Raised explicitly: a bare `assert False` disappears under -O.
            raise AssertionError(
                f"correct step with unexpected output type {ctx.output_type!r}"
            )
        if self.correct_state == 0:
            return 0

        nextstate = cast(_MultState, ctx.output_object)
        if ctx.prompt_type == 'input':
            x, y = cast(tuple[int, int], ctx.prompt_object)
        elif ctx.prompt_type == 'state':
            state = cast(_MultState, ctx.prompt_object)
            x, y = state.x, state.y
        else:
            raise AssertionError(
                f"correct step with unexpected prompt type {ctx.prompt_type!r}"
            )
        d = _count_nonzero_digit(x) * _count_nonzero_digit(y)
        dnext = _count_nonzero_digit(nextstate.x) * _count_nonzero_digit(nextstate.y)
        return self.correct_state * (d - dnext)

    def abort_process(self, prompt: str, output: str, **references) -> bool:
        """Return ``True`` when the step described by ``references`` is wrong."""
        check = _ParseContext._extract_from_dict(references)
        return not check.correct

    def outcome_reward(self, _outcome: str, **references) -> float:
        """Reward the final answer text against ``references['input']``.

        Raises:
            ValueError: If ``references`` has no ``input`` entry.
        """
        input = references.get('input')
        if input is None:
            raise ValueError("reward model requires \"input\" for reference.")
        x, y = input
        assert isinstance(x, int) and isinstance(y, int)
        answer = x * y
        outcome = parse_outcome(_outcome)
        if outcome is None:
            return self.syntax_error
        if outcome == answer:
            return self.correct_outcome
        return self.value_error


class DetailChecker:
    """Produce a line-by-line correct/error report for one generated step.

    Calling the checker with a prompt and an output re-derives the expected
    arithmetic and returns text in which every checked item of the output is
    replaced by the ``correct`` or ``error`` message.
    """

    def __init__(
        self,
        style: MDPMultStyle | None = None,
        correct: str = "correct",
        error: str = "error"
    ):
        """Store the report segments and the two verdict strings.

        Args:
            style: Style whose ``state`` / ``action`` / ``outcome`` segments
                wrap the report; a default ``MDPMultStyle`` is used if None.
            correct: Text emitted for a correct item.
            error: Text emitted for an incorrect item.
        """
        if style is None:
            style = MDPMultStyle()

        self._state = _to_reflect_segment(style.state)
        self._action = _to_reflect_segment(style.action)
        self._outcome = _to_reflect_segment(style.outcome)
        self._correct = correct
        self._error = error

    def _msg(self, correct: bool):
        """Map a boolean verdict to the configured correct/error text."""
        return self._correct if correct else self._error

    def _check_init(self, x: int, y: int, output: str) -> str:
        """Check an output produced directly from the input ``x * y``.

        The output is read as a state if possible (each of x, y, z is judged,
        with z expected to be 0), otherwise as an outcome compared to
        ``x * y``; an unparsable output yields the bare error text.
        """
        if (state := parse_state(output)) is None:
            if (outcome := parse_outcome(output)) is None:
                return self._error
            else:
                return self._outcome(self._msg(x * y == outcome))
        else:
            msg_x = self._msg(state.x == x)
            msg_y = self._msg(state.y == y)
            msg_z = self._msg(state.z == 0)
            return self._state(f"{msg_x} * {msg_y} + {msg_z}")

    def _check_action(self, state: _MultState, action: str) -> str:
        """Check the action text and advance ``state`` accordingly.

        ``action`` is matched without whitespace (``__call__`` passes the
        whitespace-stripped content of the action segment).  For a valid
        reduction, ``state`` is updated in place with the values returned by
        ``_check_reduce``.

        NOTE(review): the "right" branch passes the detail report as a second
        positional argument to the segment and its error message has no
        trailing colon, whereas the "left" branch concatenates the report
        after ":\\n". Confirm against ``Segment.__call__`` whether this
        difference is intended.
        """
        if m := re.match(r"(\d)\*right:", action):
            d = int(m.group(1))
            text = action.removeprefix(m.group(0))
            # The digit must be non-zero and actually occur in x.
            if d == 0 or d not in n2d(state.x):
                return self._action(f"{self._error} * right")
            else:
                r, state.z, state.x = self._check_reduce(state.y, d, state.x, state.z, text)
                return self._action(f"{self._correct} * right:", r)
        elif m := re.match(r"left\*(\d):", action):
            d = int(m.group(1))
            text = action.removeprefix(m.group(0))
            # The digit must be non-zero and actually occur in y.
            if d == 0 or d not in n2d(state.y):
                return self._action(f"left * {self._error}:\n")
            else:
                r, state.z, state.y = self._check_reduce(state.x, d, state.y, state.z, text)
                return self._action(f"left * {self._correct}:\n" + r)
        else:
            return self._action(self._error)

    def _check_reduce(self, v: int, ui: int, u: int, z: int, action: str):
        """Check the reduction of digit ``ui`` from ``u`` in ``u * v`` / ``v * u``.

        The text before the first "cumulate" is checked against the digit
        products of ``ui * v``; every following "cumulate" section is checked
        by ``_check_cum``.  Checking stops at the first section for which
        ``_check_cum`` cannot return both a new ``z`` and a new ``u``.

        Returns:
            ``(report, z, u)``: the newline-joined report and the values of
            ``z`` and ``u`` after the last fully parsed section.
        """

        splits = action.split("cumulate")
        prod = splits[0]
        lines: list[str] = []
        targets = list(_prod_targets(ui, v))
        lines.extend(self._check_items(targets, prod, _check_final_eq))

        p = v * ui
        for cumulate in splits[1:]:
            temp, z_, u_ = self._check_cum(z, p, ui, u, cumulate)
            lines.extend(temp)
            if z_ is None or u_ is None:
                break
            else:
                z, u = z_, u_
            del temp, z_, u_

        return '\n'.join(lines), z, u

    def _check_cum(self, z: int, p: int, ui: int, u: int, text: str):
        """Check one "cumulate" section of a reduction.

        Assuming that reducing ``ui`` from ``u`` yields the product ``p``,
        check that the accumulated value is ``p * 10**k`` for a shift ``k``
        such that the k-th digit of ``u`` is ``ui``, and check the column
        results of the addition to ``z``.

        Returns:
            ``(lines, new_z, new_u)``.  ``new_z`` is the value stated after
            "get" in the text (None if it is not an integer), not a
            recomputed sum.  ``new_u`` is ``u`` with the matched digit
            cleared, or None when the section cannot be parsed.
        """

        def _count_tailed_zeros(a: str, b: str):
            """Return n if ``b`` is ``a`` followed by n zeros, else -1."""
            n = len(b) - len(a)
            if b == a + "0" * n:
                return n
            else:
                return -1

        # Groups: accumulated value, column items, value stated after "get".
        m = re.match(r"(.*):(.*?)get(\d*)", text)
        if m is None:
            return ["cumulate " + self._error], None, None

        du = n2d(u)
        lines = []
        a, items, new_z = m.groups()
        try:
            new_z = int(new_z)
        except ValueError:
            new_z = None
        try:
            da = n2d(int(a))
        except ValueError:
            return ["cumulate " + self._error], new_z, None
        # The shift j is valid only if digit j of u is the reduced digit.
        j = _count_tailed_zeros(str(p), a)
        if 0 <= j < len(du) and du[j] == ui:
            du[j] = 0
            lines.append(f"cumulate {self._correct}:")
        else:
            lines.append(f"cumulate {self._error}:")

        targets = list(_add_targets(n2d(z), da))
        lines.extend(self._check_items(targets, items, _check_final_eq))
        lines.append(f"get {self._msg(new_z == int(a) + z)}.")

        return lines, new_z, d2n(du)

    def _check_items(self, targets: list[int], text: str,
                     item_checker: Callable[[str, int | None], bool],
                     sep: str = "-"):
        """Judge each ``sep``-separated item of ``text`` against ``targets``.

        The text before the first separator is ignored.  Items beyond the end
        of ``targets`` are checked against ``None``; every target without a
        corresponding item adds an error line.
        """
        items = text.split(sep)[1:]
        lines = []
        for i, item in enumerate(items):
            target = None if i >= len(targets) else targets[i]
            lines.append(f"- {self._msg(item_checker(item, target))}")
        for i in range(len(items), len(targets)):
            lines.append(f"- {self._error}")
        return lines

    def __call__(self, prompt: str, output: str) -> str:
        """Return the correct/error report for ``output`` given ``prompt``.

        The prompt is read as a state, or failing that as an input (turned
        into the state ``x * y + 0``).  If the output contains an action
        segment it is checked first, which also advances the state; the
        output's next state (or outcome) is then compared with that state
        (or with the prompt's value).
        """
        if (state := parse_state(prompt)) is None:
            if (input := parse_input(prompt)) is None:
                return self._error  # The prompt is invalid
            else:
                x, y = input
                state = _MultState(x, y, 0)
        # Value of the prompt, taken before the action check mutates `state`.
        goal = state.get_outcome()
        lines: list[str] = []
        if (m := re.search(RE.action, RE.remove_white_space(output))) is not None:
            action = m.group(1)
            lines.append(self._check_action(state, action))
        if (nextstate := parse_state(output)) is None:
            if (outcome := parse_outcome(output)) is None:
                lines.append(self._error)  # The output is invalid
            else:
                lines.append(self._outcome(self._msg(outcome == goal)))
        else:
            lines.append(self._state("%s * %s + %s" % (
                self._msg(state.x == nextstate.x),
                self._msg(state.y == nextstate.y),
                self._msg(state.z == nextstate.z)
            )))
        return "\n".join(lines)


def _prod_targets(ui: int, v: int):
    """Yield the column values of the long multiplication ``ui * v``.

    Each value is ``ui`` times a digit of ``v`` (least-significant first) plus
    the carry from the previous column; a final non-zero carry is yielded too.
    """
    carry = 0
    for digit in n2d(v):
        column = ui * digit + carry
        yield column
        carry = column // 10
    if carry != 0:
        yield carry


def _check_final_eq(item: str, target: int | None):
    try:
        pred = int(item.split("=")[-1])
    except ValueError:
        pred = -1
    if target is None:
        target = 0
    return pred == target


def _add_targets(da: Digits, db: Digits):
    """Yield the column sums of adding two least-significant-first digit lists.

    Each sum includes the carry from the previous column; a final non-zero
    carry is yielded as an extra value.
    """
    width = max(len(da), len(db))
    carry = 0
    for pos in range(width):
        column = carry
        if pos < len(da):
            column += da[pos]
        if pos < len(db):
            column += db[pos]
        yield column
        carry = column // 10
    if carry:
        assert carry < 10
        yield carry