import os
from pathlib import Path
from typing import Literal, Iterable, Any, Final, Callable, Mapping
import random

from core.reasoning import CoT, ReasoningTask, RewardModel, Evaluator
from core.reasoning.task import Split
from core.reasoning.reflection.protocols import (
    BinaryVerifiable,
    DetailVerifiable,
    RejectRatioEvaluated,
)

from . import impl
from .utils import (
    Multiplier,
    MultCoT,
    generate_instances,
    MultEvaluator,
    random_input,
)


# Names of the selectable multiplier implementations for ``Mult``.
# Passing ``implementation=None`` to ``Mult`` selects the direct
# implementation (``impl.direct``) instead of any of these.
type MultImpl = Literal["mdp-v0", "mdp-v1", "multistep"]


class Mult(ReasoningTask,
           BinaryVerifiable, DetailVerifiable, RejectRatioEvaluated):
    
    default_root = Path("data/mult/")

    type SizeMap = Mapping[tuple[int, int], int]

    def __init__(
        self,
        root: Path | str | None = None,
        split_sizes: dict[Split, SizeMap] | None = None,
        implementation: MultImpl | None = None,
        generate_new_data: Literal["auto"] | bool = "auto",
        noise: float = 0,
        completeness: float = 1,
        max_reduction: int | None = None,
    ):

        self.root = root
        self.split_sizes = split_sizes or {}
        self.generate_new_data: Literal["auto"] | bool = generate_new_data
        self.implementation: Final[MultImpl | None] = implementation

        # initialize multiplier
        multiplier: Multiplier
        if implementation is None:
            multiplier = impl.direct.direct_multiplier
            self.style = impl.direct.DirectMultStyle()
            self.data_options = dict(token_level=True)
        elif implementation == 'multistep':
            multiplier = impl.multistep.MultistepMultiplier(noise, completeness)
            self.style = impl.multistep.MultistepMultStyle()
            self.data_options = dict(
                token_level = True,
                step_level = True,
                predict_outcome = True,
                predict_init = True,
            )
        elif implementation == "mdp-v0":
            multiplier = impl.mdp_v0.MDPMultiplier(noise)
            self.style = impl.mdp_v0.MDPMultStyle()
            self.data_options = dict(
                token_level = False,
                step_level = False,
                state_level = True,
                policy_level = True,
                predict_outcome = True,
                predict_init = True,
            )
        elif implementation == "mdp-v1":
            multiplier = impl.mdp_v1.MDPMultiplier(noise, max_reduction)
            self.style = impl.mdp_v1.MDPMultStyle()
            self.data_options = dict(
                token_level = False,
                step_level = False,
                state_level = True,
                policy_level = True,
                predict_outcome = True,
                predict_init = True,
            )
        else:
            raise NotImplementedError(implementation)

        self.multiplier = multiplier
    
    def _generate_instances(self, split: str, save_to_file: bool = False):
        ns = self.split_sizes.get(split)
        if ns is None:
            raise ValueError(f"The generation size of {split} data is not configured.")
        data = []
        for (d1, d2), n in ns.items():
            samples = generate_instances(self.multiplier, d1, d2, n, verbose=1)
            data.extend(samples)
        if save_to_file:
            os.makedirs(self.root, exist_ok=True)
            f = self.root / f'{split}.cot.json'
            print(f"Saving {len(data)} samples to {f}: ", end='')
            self.style.save_instances(data, f)
            print("Done!")
        return data
    
    def get_instances(self, split: str) -> list[MultCoT]:
        if self.generate_new_data is True:
            return self._generate_instances(split, save_to_file=False)
        else:
            f = self.root / f'{split}.cot.json'
            try:
                data = MultCoT.load_instances(f)
                print(f"Loaded {len(data)} {split} data from {f}.")
            except FileNotFoundError:
                if self.generate_new_data == "auto":
                    print(f"Generating {split} data to \"{f}\".")
                    data = self._generate_instances(split, save_to_file=True)
                else:
                    raise
            return data
    
    def sample(self, split: Split, require_thought: bool = False) -> CoT:
        return self.samples(split, 1, require_thought)[0]
    
    def samples(self, split: Split, n: int, require_thought: bool = False) -> list[CoT]:
        sizemap = self.split_sizes.get(split)
        if sizemap is None:
            raise ValueError(f"The generation size of \"{split}\" data is not configured.")
        digits = list(sizemap.keys())
        weights = list(sizemap.values())
        digits = random.choices(digits, weights, k=n)

        multiplier: Multiplier
        if require_thought:
            multiplier = self.multiplier
        else:
            multiplier = lambda input: (None, input[0] * input[1])

        cots: list[CoT] = []
        for dx, dy in digits:
            input = random_input(dx, dy)
            thought, outcome = multiplier(input)
            cots.append(MultCoT(input, thought, outcome))
        return cots

    def get_ref_dict(self, cot: CoT) -> dict[str, Any]:
        if not isinstance(cot, MultCoT):
            raise TypeError
        return {'input': cot.input, 'digits': cot.digits}

    def reward_model(
        self,
        supervision: Literal['outcome', 'process', 'success'],
    ) -> RewardModel:
        
        if self.implementation is None:
            RM = impl.direct.DirectRM
            return RM(0., 0., 1.)
        elif self.implementation == "multistep":
            RM = impl.multistep.MultistepRM
            if supervision == 'outcome':
                rm = RM(-1., 0., 1.)
            elif supervision == 'success':
                rm = RM(0., 0., 1.)
            elif supervision == 'process':
                raise NotImplementedError
            else:
                assert False
        elif self.implementation == "mdp-v0":
            RM = impl.mdp_v0.MDPMultRM
            if supervision == 'outcome':
                rm = RM(-1, -1, 0, 1)
                rm.disable_prm = True
            elif supervision == 'success':
                rm = RM(0, 0, 0, 1)
                rm.disable_prm = True
            elif supervision == 'process':
                rm = RM(-1, -1, 0, 1)
                rm.disable_orm = True
            else:
                assert False
        elif self.implementation == "mdp-v1":
            RM = impl.mdp_v1.MDPMultRM
            if supervision == 'outcome':
                rm = RM(-1, -1, 0, 1)
                rm.disable_prm = True
            elif supervision == 'success':
                rm = RM(0, 0, 0, 1)
                rm.disable_prm = True
            elif supervision == 'process':
                rm = RM(0, 0, 0, 1)
                rm.disable_orm = True
            else:
                assert False
        else:
            raise NotImplementedError(self.implementation)
        
        return rm

    def binary_verifier(self) -> Callable[[str, str], bool]:
        if self.implementation == "mdp-v0":
            parse = impl.mdp_v0.parse_context
            return lambda prompt, output: parse(prompt, output).correct
        elif self.implementation == "mdp-v1":
            parse = impl.mdp_v1.parse_context
            return lambda prompt, output: parse(prompt, output).correct
        else:
            raise NotImplementedError
        
    def detail_verifier(self, true: str, false: str) -> Callable[[str, str], str]:
        if self.implementation == "mdp-v1":
            assert isinstance(self.style, impl.mdp_v1.MDPMultStyle)
            return impl.mdp_v1.DetailChecker(self.style, true, false)
        else:
            raise NotImplementedError

    def evaluator(self, *,
                  cnt_in_result: bool = True,
                  rounding: int | None = None):
        return MultEvaluator(self.reward_model('success'), cnt_in_result, rounding)

    def reject_ratio_map(self, eval_results: Any) -> Callable[..., float]:
        return self.evaluator()._reject_rate_getter(eval_results)
