import random
import copy
import numpy as np
import json
import re
import time
import dataclasses as dc
from typing import Literal, NamedTuple, List, Tuple, Dict, Set, Any, Iterable, Sequence
from numpy.typing import NDArray

from core.reasoning.formulation import MDPStep, AgentReasoner
from core.reasoning.style import MDPStyle, Segment
from core.reasoning.rm import RewardModel

from . import sudoku_utils as ut
from .sudoku_utils import Board, InvalidSudokuError



# Type aliases (PEP 695 `type` statements); neither is referenced in the
# visible part of this file.
# NOTE(review): `In` presumably names the solver's input type (a puzzle
# string) -- confirm against the callers.
type In = str
# NOTE(review): `_MDPAction` is left as `Any`; its intended shape is not
# shown here.
type _MDPAction = Any


@dc.dataclass(repr=False)
class _SolverState:
    """One node of the Sudoku search tree.

    A state remembers the state and the action that produced it, so that a
    finished search can be unrolled into a list of steps with `path`.
    """

    # State this one was expanded from; ``None`` for the root.
    parent_state: "_SolverState | None"
    # Action text recorded on the edge from the parent; ``None`` for the root.
    parent_action: str | None
    # Board of this node; `terminated` treats a cell equal to 0 as unfilled.
    board: Board

    def __repr__(self) -> str:
        """Return the textual form of the board produced by `ut.repr_board`."""
        return ut.repr_board(self.board)

    def children(self, randomize_traversal: bool) -> Iterable["_SolverState"]:
        """Yield the successor states of this board.

        The board is turned into a candidate grid and passed through
        `ut.reduce_grid`:

        * if the reduction raises `InvalidSudokuError`, nothing is yielded;
        * if the reduction changed the board, a single "reduce" child holding
          the reduced board is yielded;
        * otherwise one "guess" child is yielded per candidate of the
          unsolved cell that has the fewest candidates.

        Every child carries the same action text: the "rows", "cols" and
        "blocks" lines built from `ut.board_subsets`, followed by either
        "reduce" or "guess".

        Args:
            randomize_traversal: Shuffle the guess candidates before
                yielding them. Has no effect on the "reduce" branch.

        Yields:
            Child states whose `parent_state` is this state.
        """
        grid = ut.board_to_grid(self.board, fill_possibilities=True)

        try:
            _, grid = ut.reduce_grid(grid)
        except InvalidSudokuError:
            # Dead end: an invalid board has no successors.
            return

        reduced_board = ut.grid_to_board(grid)
        # Reduction made no progress, so the only way forward is a guess.
        need_guess = bool(np.all(reduced_board == self.board))

        board_info = ut.board_subsets(self.board)
        action_lines: list[str] = [
            " ".join([key] + board_info[key])
            for key in ('rows', 'cols', 'blocks')
        ]

        if not need_guess:
            action_lines.append("reduce")
            yield _SolverState(self, "\n".join(action_lines), reduced_board)
            return

        action_lines.append("guess")
        action = "\n".join(action_lines)

        # Branch on an uncertain cell with the fewest possibilities.
        unsolved = [cell for cell in range(81) if len(grid[cell]) > 1]
        # NOTE(review): presumably holds because callers only expand boards
        # that are not yet complete -- confirm against `ut.reduce_grid`.
        assert unsolved
        cell_index = min(unsolved, key=lambda c: len(grid[c]))
        candidates = list(grid[cell_index])
        # Flat cell index -> (row, column) index into the 9x9 board.
        guess_pos = (cell_index // 9, cell_index % 9)
        if randomize_traversal:
            random.shuffle(candidates)
        for d in candidates:
            guessed_board = reduced_board.copy()
            guessed_board[guess_pos] = d
            yield _SolverState(self, action, guessed_board)
        # A value returned from a generator is invisible to `for` loops, so
        # the generator simply ends here instead of returning the action text.

    def terminated(self) -> bool:
        """Return True when no cell of the board equals 0."""
        return bool(np.all(self.board != 0))

    def path(self) -> list[MDPStep[str, str]]:
        """Return the steps leading from the root state to this state.

        Each step pairs the repr of a state with the action recorded for
        leaving it. The last step holds this state with action ``None``.
        """
        reversed_path: list[MDPStep[str, str]] = [{"state": repr(self), "action": None}]
        state = self
        # Walk up the parent links, collecting steps leaf-first.
        while state.parent_state is not None:
            reversed_path.append({"state": repr(state.parent_state), "action": state.parent_action})
            state = state.parent_state
        return list(reversed(reversed_path))


class SudokuMDPSolver:
    """Depth-first Sudoku solver that reports its search path as MDP steps."""

    def __init__(self,
                 randomize_traversal: bool = True,
                 max_depth: int | None = None,
                 max_time: float | None = None):
        """Configure the search.

        Args:
            randomize_traversal: Forwarded to `_SolverState.children`.
            max_depth: Maximum recursion depth; ``None`` disables the limit.
            max_time: Time budget in seconds for one call; ``None`` disables
                the limit.
        """
        self._randomize_traversal = randomize_traversal
        self._max_depth: int | None = max_depth
        # Only ``None`` means "no limit". Testing truthiness here would turn
        # an explicit budget of 0 into an unlimited one.
        self._max_time = float('inf') if max_time is None else max_time

    def __call__(self, input: str) -> tuple[list[MDPStep[str, str]] | None, str | Literal["FAILURE"]]:
        """Solve the puzzle given as a string.

        Returns:
            ``(steps, outcome)`` where ``steps`` is the path to the solved
            state and ``outcome`` is the state text of its last step, or
            ``(None, "FAILURE")`` when the search finds no solution within
            its limits.
        """
        state = _SolverState(None, None, ut.str_to_board(input))
        solution = self.search(0, state, time.time() + self._max_time)
        if solution is None:
            return None, "FAILURE"
        thought = solution.path()
        outcome = thought[-1]['state']
        return thought, outcome

    def search(self, depth: int, state: _SolverState, finish_time: float) -> _SolverState | None:
        """Depth-first search for a terminated state below ``state``.

        Args:
            depth: Depth of ``state`` in the search tree.
            state: State to expand.
            finish_time: Deadline, compared against `time.time()`.

        Returns:
            The first terminated state found, or ``None`` when the depth
            limit or the deadline is hit or every branch is a dead end.
        """
        if state.terminated():
            return state
        if self._max_depth is not None and depth >= self._max_depth:
            return None
        if time.time() > finish_time:
            return None
        for child in state.children(self._randomize_traversal):
            term = self.search(depth + 1, child, finish_time)
            if term is not None:
                return term
        return None


class MDPSudokuStyle(MDPStyle[str, str, str, str]):
    """Tag markup for the state, action and outcome parts of a Sudoku trace.

    Each part is declared as a `Segment` with an opening tag, a closing tag
    and a newline as `space`.
    NOTE(review): how `MDPStyle` consumes these attributes is not visible
    here -- confirm before reordering or renaming them.
    """

    # <state>...</state>
    state = Segment('<state>', '</state>', space='\n')
    # <action>...</action>
    action = Segment('<action>', '</action>', space='\n')
    # <out>...</out>
    outcome = Segment('<out>', '</out>', space='\n')


class RE:
    """Namespace holding the precompiled regular expressions of this module."""

    # One or more whitespace characters.
    white_space = re.compile(r"\s+")
    # Everything between <action> and </action>; greedy, and `.` matches
    # newlines as well.
    action = re.compile(r"<action>(.*)</action>", re.DOTALL)

    # Same shape as `action`, one pattern per board-carrying tag.
    board_patterns = {
        "state": re.compile(r"<state>(.*)</state>", re.DOTALL),
        "in": re.compile(r"<in>(.*)</in>", re.DOTALL),
        "out": re.compile(r"<out>(.*)</out>", re.DOTALL),
    }

    @staticmethod
    def remove_white_space(text: str):
        """Return ``text`` with every run of whitespace deleted."""
        return RE.white_space.sub('', text)
    

def parse_cells(text: str, segment: Literal["state", "in", "out"]) -> list[str] | None:
    """Extract the cells found inside a ``<segment>...</segment>`` block.

    The match is greedy: if the tag occurs several times, the captured text
    runs from the first opening tag to the last closing tag.

    Args:
        text: Text to search.
        segment: Tag name, which selects the pattern in `RE.board_patterns`.

    Returns:
        ``None`` when the tag pair is absent. For ``"in"`` the content is
        stripped of whitespace and split into single characters; for the
        other tags it is split on whitespace.
    """
    found = RE.board_patterns[segment].search(text)
    if found is None:
        return None

    body = found.group(1)
    if segment != "in":
        return body.split()
    return list(RE.remove_white_space(body))


class _ParseContext(NamedTuple):

    prompt_type: Literal["input", "state", None]
    output_type: Literal["state", "outcome", None]
    prompt_cells: list[str] | None
    output_cells: list[str] | None
    
    @classmethod
    def _extract_from_dict(cls, d: dict):
        return _ParseContext(*(d[f] for f in cls._fields))

    def detailed_check(self):
        return _check_errors(self.prompt_cells or [], self.output_cells or [])
    
    @property
    def correct(self) -> bool:
        if self.prompt_cells is None or self.output_cells is None:
            return False
        else:
            errors = _check_errors(self.prompt_cells, self.output_cells)
            return not bool(errors)


class _ErrorInfo(NamedTuple):
    """Result of a board check: per-cell error flags plus surplus cell count."""

    # Boolean mask; True marks a cell that has an error.
    cells: NDArray[np.bool_]
    # Count of cells supplied beyond the board size.
    redundancy: int

    def __bool__(self):
        """Truthy when any cell is flagged or there are redundant cells."""
        any_cell_flagged = bool(self.cells.any())
        if any_cell_flagged:
            return True
        return self.redundancy > 0


def _check_errors(old: list[str], new: list[str]) -> _ErrorInfo:
    """Compare a new list of cells against the previous one.

    Both lists are fitted to exactly 81 cells (padded with '.' or truncated)
    and converted with `ut.str_to_board(..., invalid=-1)`. A cell of the new
    board is flagged when any of the following holds:

    * the old board has a value > 0 there and the new value differs;
    * the new value is negative (the `invalid=-1` marker);
    * `ut.conflict_flags` flags it.

    Args:
        old: Cells of the previous board.
        new: Cells of the board being checked.

    Returns:
        `_ErrorInfo` with the flag mask and the number of cells in ``new``
        beyond the first 81.
    """

    def _fit_to_board(cells: list[str]):
        # Pad short inputs with '.', cut long ones down to 81 cells.
        missing = 81 - len(cells)
        if missing > 0:
            fitted = cells + ['.'] * missing
        else:
            fitted = cells[:81]
        # `-1` stands for missing and invalid numbers.
        return ut.str_to_board(fitted, invalid=-1)

    previous = _fit_to_board(old)
    current = _fit_to_board(new)

    # A cell already determined in the old board must keep its value.
    changed_given = (current != previous) & (previous > 0)
    unreadable = (current < 0)
    clashing = ut.conflict_flags(current)

    flagged: NDArray[np.bool_] = (changed_given | unreadable | clashing)
    surplus = len(new) - 81
    if surplus < 0:
        surplus = 0

    return _ErrorInfo(flagged, surplus)


def parse_context(prompt: str, output: str) -> _ParseContext:
    """Parse the board cells carried by a prompt and by its output.

    For the prompt, a ``state`` block is preferred and an ``in`` block is the
    fallback. For the output, a ``state`` block is preferred and an ``out``
    block is the fallback. When neither tag is present, the corresponding
    type and cells are ``None``.
    """
    prompt_type = None
    prompt_cells = None
    if (found := parse_cells(prompt, "state")) is not None:
        prompt_type, prompt_cells = "state", found
    elif (found := parse_cells(prompt, "in")) is not None:
        prompt_type, prompt_cells = "input", found

    output_type = None
    output_cells = None
    if (found := parse_cells(output, "state")) is not None:
        output_type, output_cells = "state", found
    elif (found := parse_cells(output, "out")) is not None:
        output_type, output_cells = "outcome", found

    return _ParseContext(
        prompt_type=prompt_type,
        output_type=output_type,
        prompt_cells=prompt_cells,
        output_cells=output_cells,
    )


class SudokuRM(RewardModel):
    """Rule-based reward model: 1.0 for an error-free board, 0.0 otherwise."""

    def _parse_process(self, prompt: str, output: str) -> Dict[str, Any]:
        """Parse one step and report whether it is free of errors."""
        parsed = parse_context(prompt, output)
        return {"parsed": parsed, "correct": parsed.correct}

    def outcome_reward(self, _outcome: str, **references) -> float:
        """Score the final ``<out>`` board against the original puzzle.

        ``references["input"]`` must be the puzzle as an 81-character string.
        Returns 0.0 when no ``<out>`` block is found or the check reports
        errors, and 1.0 otherwise.
        """
        puzzle = references["input"]
        assert isinstance(puzzle, str) and len(puzzle) == 81
        cells = parse_cells(_outcome, "out")
        if cells is None:
            return 0.
        report = _check_errors(list(puzzle), cells)
        return 0. if bool(report) else 1.

    def process_reward(self, prompt: str, output: str, **references) -> float:
        """Return 1.0 when ``references["correct"]`` is truthy, else 0.0.

        NOTE(review): "correct" is presumably the value produced by
        `_parse_process` -- confirm against `RewardModel`.
        """
        correct: bool = references["correct"]
        if correct:
            return 1.
        return 0.

    def abort_process(self, prompt: str, output: str, **references) -> bool:
        """Return the negation of ``references["correct"]``."""
        step_ok: bool = references["correct"]
        return not step_ok


class DetailChecker:

    def __init__(
        self,
        correct: str = "correct",
        error: str = "error"
    ):
        self._correct = correct
        self._error = error

    def _flag2label(self, error: bool | np.bool_):
        return self._error if error else self._correct

    def __call__(self, prompt: str, output: str) -> str:
        ctx = parse_context(prompt, output)
        errors = ctx.detailed_check()
        assert errors.cells.shape == (9, 9)
        lines = []
        for row in range(9):
            lines.append(' '.join(map(self._flag2label, errors.cells[row, :])))
        if errors.redundancy > 0:
            lines.append(' '.join([self._error] * errors.redundancy))
        return '\n'.join(lines)
