import numpy as np
import numpy.random as rd
import random
import re
import dataclasses as dc
from core.reasoning.formulation import MDPStep, AgentReasoner
from core.reasoning.style import MDPStyle, Segment
from core.reasoning.rm import RewardModel
import itertools
from typing import Literal, NamedTuple, Callable
from ..utils import In


# Decimal digits of an integer, stored least-significant digit first
# (see `_digits_of`, which reverses the decimal string).
type Digits = list[int]


def _digits_of(number: int) -> Digits:
    """Return the decimal digits of ``number``, least-significant digit first."""
    text = str(number)
    return [int(ch) for ch in text[::-1]]

def _pop_zeros(digits: Digits):
    """Remove trailing zero entries from ``digits`` in place.

    Because digits are stored least-significant first, the trailing entries
    are the leading zeros of the number. The list may end up empty.
    """
    keep = len(digits)
    while keep and digits[keep - 1] == 0:
        keep -= 1
    del digits[keep:]


@dc.dataclass(repr=False)
class _MultState:
    """Mutable working state of a digit-wise long multiplication.

    ``x`` and ``y`` hold the digits of the two factors that have not been
    consumed yet, least-significant digit first. ``product`` holds, for each
    decimal position (least-significant first), the single-digit terms that
    still have to be added up at that position. Consuming a digit of one
    factor moves its contribution into ``product``, so the value
    ``x * y + product`` is the same before and after every reduction.

    The ``x`` and ``y`` lists passed to the constructor are modified in place.
    ``result`` is filled in by ``check_result`` once the state has collapsed
    to a plain number.
    """

    x: Digits
    y: Digits
    product: list[list[int]] = dc.field(init=False)
    result: int | None = None

    def __post_init__(self):
        # One (initially empty) term list per decimal position.
        self.product = [[] for _ in range(len(self.x) + len(self.y))]
        # A zero factor is already a finished state.
        self.check_result()

    @staticmethod
    def _repr_numbers(numbers: list[int], pre='', sep='', suf='', empty='') -> str:
        """Render ``numbers`` most-significant first, joined by ``sep``.

        ``empty`` is used in place of the digits when ``numbers`` is empty;
        ``pre`` and ``suf`` are put around the rendered text.
        """
        number = sep.join(str(d) for d in reversed(numbers)) if numbers else empty
        return pre + number + suf

    def _repr_product(self, prefix='', sep='; ', suffix='') -> str:
        """Render ``product`` most-significant position first.

        Terms within one position are joined by ``' + '``; a position without
        terms is shown as ``'0'``. Empty high positions are omitted, but at
        least one position is always shown.
        """
        # Work on a shallow copy so trimming does not alter the state.
        product = self.product.copy()
        while len(product) > 1 and not product[-1]:
            product.pop()

        return prefix + sep.join(
            (' + '.join(str(d) for d in items) if items else '0')
            for items in reversed(product)
        ) + suffix

    def __repr__(self) -> str:
        return ',\n'.join([
            self._repr_numbers(self.x, pre='x: ', sep=' ', empty='0'),
            self._repr_numbers(self.y, pre='y: ', sep=' ', empty='0'),
            self._repr_product('prod: ', sep=' | '),
        ])

    def _scatter(self, value: int, shift: int) -> None:
        """Add the non-zero digits of ``value * 10**shift`` to ``product``.

        Zero digits are never stored, so every reduction produces the same
        representation of the same partial product.
        """
        for k, digit in enumerate(_digits_of(value)):
            if digit != 0:
                self.product[shift + k].append(digit)

    def reduce_x(self, i: int) -> str:
        """Consume digit ``i`` of ``x`` and move ``x[i] * y`` into ``product``.

        Returns the textual description of the action.
        Raises ``IndexError`` when ``i`` is not a valid position of ``x``.
        """
        xi = self.x[i]
        self.x[i] = 0
        _pop_zeros(self.x)

        terms: list[str] = []
        for j, yj in enumerate(self.y):
            z = xi * yj
            if z == 0:
                continue
            terms.append(str(z) + '0' * j)
            self._scatter(z, i + j)

        zs = ' + '.join(reversed(terms)) if terms else '0'
        return "{xi} * y = {zs} @ 1{zeros}".format(
            xi=xi,
            zs=zs,
            zeros='0' * i,
        )

    def reduce_y(self, j: int) -> str:
        """Consume digit ``j`` of ``y`` and move ``x * y[j]`` into ``product``.

        Returns the textual description of the action.
        Raises ``IndexError`` when ``j`` is not a valid position of ``y``.
        """
        yj = self.y[j]
        self.y[j] = 0
        _pop_zeros(self.y)

        terms: list[str] = []
        for i, xi in enumerate(self.x):
            z = xi * yj
            if z == 0:
                continue
            terms.append(str(z) + '0' * i)
            # Zero digits are skipped here exactly as in ``reduce_x``; storing
            # them would add "+ 0" terms and force a needless ``reduce_prod``.
            self._scatter(z, i + j)

        zs = ' + '.join(reversed(terms)) if terms else '0'
        return "x * {yj} = {zs} @ 1{zeros}".format(
            yj=yj,
            zs=zs,
            zeros='0' * j,
        )

    def reduce_prod(self) -> str:
        """Add up the terms of every position and redistribute the carries.

        Each position's sum is split back into digits, which are stored at
        that position and the following higher ones. Returns the textual
        description of the action, listing the per-position sums.
        """
        sums = [sum(items) for items in self.product]
        while len(sums) > 1 and sums[-1] == 0:
            sums.pop()

        action = "prod = %s" % (' | '.join(map(str, reversed(sums))))

        for items in self.product:
            items.clear()

        for k, s in enumerate(sums):
            self._scatter(s, k)

        return action

    def check_result(self):
        """Set ``result`` if the state has collapsed to a plain number.

        Nothing happens while both factors still contain a non-zero digit or
        while any position of ``product`` holds two or more terms.
        """
        prod_digits: Digits = []

        if any(self.x) and any(self.y):
            return

        for nums in self.product:
            if len(nums) >= 2:
                return
            if nums:
                assert nums[0] < 10
                prod_digits.append(nums[0])
            else:
                prod_digits.append(0)

        _pop_zeros(prod_digits)
        self.result = int(self._repr_numbers(prod_digits, empty='0'))
    

class _MDPAction(NamedTuple):
    """One action of the multiplication MDP: an operation name and an index."""

    # Name of the operation to apply; 'pass' applies no reduction.
    fn: Literal['reduce_x', 'reduce_y', 'reduce_prod', 'pass']
    # Digit position for 'reduce_x' / 'reduce_y'; the actor passes -1 for
    # 'reduce_prod' and 'pass', which do not use it.
    idx: int


class MDPMultiplier(AgentReasoner[In, _MDPAction, list[MDPStep[str, str]], int]):

    def __init__(self, noise: float = 0):
        self._noise = noise
        if noise:
            self._correct = lambda: rd.rand() >= noise
        else:
            self._correct = lambda: True

        self._thought: list[MDPStep[str, str]] = [] 
        self._state: _MultState

    @staticmethod
    def _digits_of(number: int) -> Digits:
        return list(int(ch) for ch in reversed(str(number)))
    
    def begin(self, input: tuple[int, int]) -> None:
        x, y = input
        self._state = _MultState(self._digits_of(x), self._digits_of(y))
        self._thought.clear()
    
    def actor(self) -> _MDPAction:
        choices : list[_MDPAction]

        state = self._state
        x = state.x
        y = state.y
        prod = state.product

        if self._correct():
            
            if state.result is not None:
                return _MDPAction('pass', -1)

            choices = []
            if any(len(items) > 1 for items in prod):
                choices.append(_MDPAction('reduce_prod', -1))
            if any(x) and any(y):
                choices.extend(_MDPAction('reduce_x', i) for i in range(len(x)) if x[i])
                choices.extend(_MDPAction('reduce_y', j) for j in range(len(y)) if y[j])
            if choices:
                return random.choice(choices)
        
        choices = [_MDPAction('reduce_prod', -1)] + list(itertools.chain(
            (_MDPAction('reduce_x', i) for i in range(len(x))),
            (_MDPAction('reduce_y', j) for j in range(len(y))),
        ))
        return random.choice(choices)
    
    def transit(self, action: _MDPAction) -> bool | None:
        self._thought.append({'state': repr(self._state), 'action': None})

        if action.fn == 'reduce_x':
            a = self._state.reduce_x(action.idx)
        elif action.fn == 'reduce_y':
            a = self._state.reduce_y(action.idx)
        elif action.fn == 'reduce_prod':
            a = self._state.reduce_prod()
        else:
            a = None

        self._state.check_result()
        self._thought[-1]['action'] = a
        if a is None and self._state.result is not None:
            return True
        
    def end(self) -> tuple[list[MDPStep[str, str]], int]:
        assert self._state.result is not None
        return self._thought.copy(), self._state.result


class MDPMultStyle(MDPStyle[In, str, str, int]):
    """Text style for the multiplication MDP: input, state and outcome."""

    # States are wrapped in <state> ... </state>, spaced by newlines.
    state = Segment('<state>', '</state>', space='\n')

    def repr_input(self, input: In, **kwargs) -> str:
        """Render the pair of factors as ``"<x> * <y>"``."""
        lhs, rhs = input
        template = "%d * %d"
        return template % (lhs, rhs)

    def repr_outcome(self, outcome: int, **kwargs) -> str:
        """Render the final answer as ``"Outcome: <n>"``."""
        template = "Outcome: %d"
        return template % outcome


class RE:
    """Compiled patterns for the tagged segments of prompts and outputs."""

    # <state> block: three captured fields for x, y and prod. x and y capture
    # digits/whitespace; prod additionally allows '|' and '+'.
    state = re.compile(r"<state>\s*x\s*:([\s\d]*),\s*y\s*:([\s\d]*),\s*prod\s*:([\s\d|+]*)\s*</state>")
    # <in> block: "<x> * <y>". Both operands are required. A repeated group
    # `(\d+)*` here would make the first operand optional (leaving its group
    # as None) and nests two quantifiers, which can backtrack badly.
    input = re.compile(r"<in>\s*(\d+)\s*\*\s*(\d+)\s*</in>")
    # <out> block: "Outcome: <n>".
    outcome = re.compile(r"<out>\s*Outcome\s*:\s*(\d+)\s*</out>")


def parse_input(text: str) -> tuple[int, int] | None:
    """Extract the two factors from the first ``<in>`` segment in ``text``.

    Returns ``(x, y)`` as integers, or ``None`` when no segment is found or
    an operand is missing or cannot be converted.
    """
    m = RE.input.search(text)
    if m is None:
        return None
    x, y = m.groups()
    # A group that did not take part in the match is reported as None;
    # treat that as unparseable rather than letting int(None) raise TypeError.
    if x is None or y is None:
        return None
    try:
        return int(x), int(y)
    except ValueError:
        return None


def parse_state(text: str):
    """Extract ``(x, y, prod)`` from the first ``<state>`` segment in ``text``.

    ``x`` and ``y`` are read by concatenating their whitespace-separated
    digits. ``prod`` is the number obtained by summing the ``+``-separated
    terms of each ``|``-separated position and weighting the positions by
    powers of ten, rightmost position first. Returns ``None`` when no segment
    is found or a field cannot be converted.
    """
    found = RE.state.search(text)
    if found is None:
        return None
    try:
        raw_x, raw_y, raw_prod = found.groups()
        x = int(''.join(raw_x.split()))
        y = int(''.join(raw_y.split()))
        columns = [
            sum(int(term) for term in column.split('+'))
            for column in raw_prod.split('|')
        ]
        # Columns are written most-significant first, so weight them in
        # reverse order.
        prod = sum(
            total * (10**power)
            for power, total in enumerate(reversed(columns))
        )
        return x, y, prod
    except ValueError:
        return None


def parse_outcome(text: str):
    """Extract the integer from the first ``<out>`` segment in ``text``.

    Returns ``None`` when no segment is found or the number cannot be
    converted.
    """
    found = RE.outcome.search(text)
    if found is None:
        return None
    try:
        return int(found.group(1))
    except ValueError:
        return None


class _ParseContext(NamedTuple):
    """Result of checking a (prompt, output) pair, as built by ``parse_context``."""
    # True when the values recovered from the prompt and the output are equal.
    correct: bool
    # What was recognized in the prompt; None when neither a state nor an
    # input could be parsed.
    prompt_type: Literal["input", "state", None]
    # What was recognized in the output; None when neither a state nor an
    # outcome could be parsed.
    output_type: Literal["state", "outcome", None]


def parse_context(prompt: str, output: str) -> _ParseContext:
    """Check whether ``output`` preserves the value encoded in ``prompt``.

    Each side is reduced to one number: ``x * y + prod`` for a state,
    ``x * y`` for an input (prompt only), the stated number for an outcome
    (output only). A state takes precedence when it can be parsed. A side
    that cannot be parsed gets NaN, which never compares equal, so
    ``correct`` is False in that case.
    """
    # Left-hand side: the prompt.
    prompt_state = parse_state(prompt)
    if prompt_state is not None:
        x, y, prod = prompt_state
        left = x * y + prod
        prompt_type = "state"
    elif (factors := parse_input(prompt)) is not None:
        x, y = factors
        left = x * y
        prompt_type = "input"
    else:
        left = np.nan
        prompt_type = None

    # Right-hand side: the output.
    output_state = parse_state(output)
    if output_state is not None:
        x, y, prod = output_state
        right = x * y + prod
        output_type = "state"
    elif (answer := parse_outcome(output)) is not None:
        right = answer
        output_type = "outcome"
    else:
        right = np.nan
        output_type = None

    return _ParseContext(left == right, prompt_type, output_type)


class MDPMultRM(RewardModel):
    """Rule-based reward model for the multiplication MDP.

    The four constructor arguments are the rewards returned for an
    unparseable text, a parseable but wrong value, a correct intermediate
    state and a correct final outcome.
    """

    def __init__(
        self,
        syntax_error: float = -1,
        value_error: float = -1,
        correct_state: float = 0,
        correct_outcome: float = 1,
    ):
        super().__init__()

        self.syntax_error: float = syntax_error
        self.value_error: float = value_error
        self.correct_state: float = correct_state
        self.correct_outcome: float = correct_outcome

    def process_reward(self, prompt: str, output: str, **references) -> float:
        """Score one step by comparing the values of ``prompt`` and ``output``."""
        check = parse_context(prompt, output)

        if not check.correct:
            # Distinguish "could not parse" from "parsed but wrong value".
            if check.output_type is None or check.prompt_type is None:
                return self.syntax_error
            return self.value_error

        if check.output_type == 'outcome':
            return self.correct_outcome
        if check.output_type == 'state':
            return self.correct_state
        # A correct check implies the output was parsed as one of the above.
        assert False

    def outcome_reward(self, _outcome: str, **references) -> float:
        """Score a final answer against ``references['input']``.

        Raises ``ValueError`` when the ``input`` reference is missing.
        """
        factors = references.get('input')
        if factors is None:
            raise ValueError("reward model requires \"input\" for reference.")
        x, y = factors
        assert isinstance(x, int) and isinstance(y, int)
        expected = x * y

        stated = parse_outcome(_outcome)
        if stated is None:
            return self.syntax_error
        return self.correct_outcome if stated == expected else self.value_error
