import os
from pathlib import Path
from typing import Literal, Any, Final, Callable

from core.reasoning import CoT, ReasoningTask, RewardModel, Evaluator
from core.reasoning.task import Split
from core.reasoning.style import DirectStyle
from . import impl
from .utils import (
    SudokuSolver,
    SudokuCoT,
    generate_instances,
    SudokuEvaluator,
)
from core.reasoning.reflection.protocols import (
    BinaryVerifiable,
    DetailVerifiable,
    RejectRatioEvaluated,
)

# Names of the available solver-backed implementations. The task class below
# additionally accepts ``None`` in place of one of these names to select the
# direct reasoning style.
type SudokuImpl = Literal["mdp-v0"]


class Sudoku(ReasoningTask,
             BinaryVerifiable, DetailVerifiable, RejectRatioEvaluated,):
    """Sudoku reasoning task: data generation/loading, rewards and verifiers.

    Two modes are supported, selected by ``implementation``:

    * ``None``: direct reasoning (``DirectStyle``), no solver is created.
    * ``"mdp-v0"``: uses the solver, style, reward model and verifiers from
      ``impl.mdp_v0``.
    """

    # Directory used when no ``root`` is passed to the constructor.
    default_root = Path("data/sudoku/")

    def __init__(
        self,
        root: Path | str | None = None,
        split_sizes: dict[Split, dict[int, int]] | None = None,
        implementation: SudokuImpl | None = None,
        generate_new_data: Literal["auto"] | bool = "auto",
        max_solution_depth: int | None = None,
        max_solving_time: float | None = None,
    ):
        """Configure the task.

        Args:
            root: Directory holding the ``{split}.cot.json`` files. Strings are
                converted to ``Path``; ``None`` falls back to ``default_root``.
            split_sizes: For each split, a mapping ``{blanks: n}`` whose items
                are forwarded to ``generate_instances`` when data is generated
                (presumably "number of blank cells -> number of samples";
                confirm against ``generate_instances``).
            implementation: ``None`` for the direct style or ``"mdp-v0"``.
            generate_new_data: ``True`` always generates fresh data without
                touching the disk; ``"auto"`` loads from file and generates
                (and saves) only when the file is missing; ``False`` loads
                from file and lets ``FileNotFoundError`` propagate.
            max_solution_depth: Passed as ``max_depth`` to the MDP solver
                (only used by ``"mdp-v0"``).
            max_solving_time: Passed as ``max_time`` to the MDP solver
                (only used by ``"mdp-v0"``).

        Raises:
            NotImplementedError: If ``implementation`` is not recognised.
        """
        # Normalise to a Path so that ``self.root / name`` works for str input
        # and for the default (None) case.
        self.root = Path(root) if root is not None else self.default_root
        self.split_sizes = split_sizes or {}
        self.generate_new_data: Literal["auto"] | bool = generate_new_data
        self.implementation: Final[SudokuImpl | None] = implementation
        self.max_solution_depth = max_solution_depth
        self.max_solving_time = max_solving_time

        sudoku_solver: SudokuSolver | None = None
        if implementation is None:  # direct reasoner
            self.style = DirectStyle()
            self.data_options = dict(token_level=True)
        elif implementation == "mdp-v0":
            sudoku_solver = impl.mdp_v0.SudokuMDPSolver(
                max_depth=self.max_solution_depth,
                max_time=self.max_solving_time,
            )
            self.style = impl.mdp_v0.MDPSudokuStyle()
            self.data_options = dict(
                token_level=False,
                step_level=False,
                state_level=True,
                policy_level=True,
                predict_outcome=True,
                predict_init=True,
            )
        else:
            raise NotImplementedError(implementation)

        # Stays None for the direct reasoner.
        self.sudoku_solver = sudoku_solver

    def _generate_instances(self, split: str, save_to_file: bool = False):
        """Generate the instances configured for ``split``.

        Args:
            split: Key into ``self.split_sizes``.
            save_to_file: When true, the samples are also written to
                ``<root>/<split>.cot.json`` via ``self.style.save_instances``.

        Returns:
            The list of generated samples, concatenated over all configured
            ``blanks`` values.

        Raises:
            ValueError: If no sizes are configured for ``split``.
        """
        ns = self.split_sizes.get(split)
        if ns is None:
            raise ValueError(f"The generation size of {split} data is not configured.")

        data = []
        for blanks, n in ns.items():
            # NOTE(review): ``self.sudoku_solver`` is None for the direct
            # reasoner; confirm ``generate_instances`` accepts that.
            data.extend(generate_instances(self.sudoku_solver, blanks, n, verbose=1))

        if save_to_file:
            os.makedirs(self.root, exist_ok=True)
            f = self.root / f'{split}.cot.json'
            print(f"Saving {len(data)} samples to {f}: ", end='')
            self.style.save_instances(data, f)
            print("Done!")
        return data

    def get_instances(self, split: str) -> list[SudokuCoT]:
        """Return the instances of ``split``, loading or generating them.

        Behaviour depends on ``self.generate_new_data``:

        * ``True``: always generate, never read or write files.
        * ``"auto"``: load ``<root>/<split>.cot.json``; if it does not exist,
          generate the data and save it there.
        * anything else: load the file and re-raise ``FileNotFoundError`` if it
          is missing.
        """
        if self.generate_new_data is True:
            return self._generate_instances(split, save_to_file=False)

        f = self.root / f'{split}.cot.json'
        try:
            data = SudokuCoT.load_instances(f)
        except FileNotFoundError:
            if self.generate_new_data != "auto":
                raise
            print(f"Generating {split} data to \"{f}\".")
            data = self._generate_instances(split, save_to_file=True)
        else:
            print(f"Loaded {len(data)} {split} data from {f}.")
        return data

    def get_ref_dict(self, cot: CoT) -> dict[str, Any]:
        """Return the reference dict for ``cot``: its ``input`` under ``'input'``."""
        return {'input': cot.input}

    def reward_model(
        self,
        supervision: Literal['outcome', 'process', 'success'],
    ) -> RewardModel:
        """Build an ``impl.mdp_v0.SudokuRM`` configured for ``supervision``.

        For ``'outcome'`` and ``'success'`` the model gets ``disable_prm = True``;
        for ``'process'`` it gets ``disable_orm = True``. The direct reasoner
        (``implementation is None``) does not support ``'process'``.

        Raises:
            NotImplementedError: For ``'process'`` supervision with the direct
                reasoner, or for an unknown implementation.
            ValueError: For an unknown ``supervision`` with ``"mdp-v0"``.
        """
        if self.implementation is None:
            if supervision == 'process':
                raise NotImplementedError
            rm = impl.mdp_v0.SudokuRM()
            rm.disable_prm = True
            return rm

        if self.implementation != "mdp-v0":
            raise NotImplementedError(self.implementation)

        if supervision in ('outcome', 'success'):
            rm = impl.mdp_v0.SudokuRM()
            rm.disable_prm = True
        elif supervision == 'process':
            rm = impl.mdp_v0.SudokuRM()
            rm.disable_orm = True
        else:
            # An explicit exception instead of ``assert False``: asserts are
            # stripped under ``python -O``, which would leave ``rm`` unbound.
            raise ValueError(f"Unknown supervision: {supervision!r}")
        return rm

    def binary_verifier(self) -> Callable[[str, str], bool]:
        """Return a ``(prompt, output) -> bool`` correctness checker.

        The checker parses the pair with ``impl.mdp_v0.parse_context`` and
        returns the ``correct`` attribute of the result.

        Raises:
            NotImplementedError: If the implementation is not ``"mdp-v0"``.
        """
        if self.implementation != "mdp-v0":
            raise NotImplementedError
        parse = impl.mdp_v0.parse_context
        return lambda prompt, output: parse(prompt, output).correct

    def detail_verifier(self, true: str, false: str) -> Callable[[str, str], str]:
        """Return an ``impl.mdp_v0.DetailChecker`` built from ``true``/``false``.

        Raises:
            NotImplementedError: If the implementation is not ``"mdp-v0"``.
        """
        if self.implementation != "mdp-v0":
            raise NotImplementedError
        return impl.mdp_v0.DetailChecker(true, false)

    def evaluator(self, *,
                  cnt_in_result: bool = True,
                  rounding: int | None = None):
        """Return a ``SudokuEvaluator`` backed by the ``'success'`` reward model."""
        return SudokuEvaluator(self.reward_model('success'), cnt_in_result, rounding)

    def reject_ratio_map(self, eval_results: Any) -> Callable[..., float]:
        """Return the default evaluator's reject-rate getter for ``eval_results``."""
        return self.evaluator()._reject_rate_getter(eval_results)
