from .formulation import CoT, Reasoner
from .style import Style
from .rm import RewardModel
from .task import (
    ReasoningTask,
    Split as _Split,
    register_task,
    is_available_task_name,
    task_from_name,
    available_task_names
)
from .evaluator import Evaluator
from .reasoners import ThoughtImpl, GenerativeReasoner
from .reflection import (ReflectiveSample as _ReflectiveSample,
                         ReflDataCollector as _ReflectiveCollector,
                         Approach as ReflectionApproach)
from pathlib import Path as _Path
import typing as _t
import core.model as _m
from core.tokenization import Vocabulary as _Vocabulary
import core.inference as _inference


def make_reasoner(
    llm: _m.LLM,
    impl: ThoughtImpl,
    context_length: int,
    *,
    task: ReasoningTask | str | None = None,
    inference_args: dict | None = None,
    reflection_args: dict | None = None,
    max_steps: int | None = None,
    **kwargs,
) -> GenerativeReasoner:
    """
    Create a reasoner with the given LLM and task.

    The cheap argument checks (`impl`, `max_steps`) run before anything is constructed, so an
    invalid call fails without building the inference or reflection objects.

    Args:
        llm (LLM): The LLM for thought/outcome planning.
        impl (str): How the reasoning process is implemented:
            - "tokenwise": Token-wise reasoning.
            - "markov-chain": Markov-Chain-style reasoning.
            - "mdp": MDP-style reasoning.
        context_length (int): The context length of the LLM.
        task (ReasoningTask | str | None): The task to be used. A string is resolved with `task_from_name`.
            If None, no task is used. This provides task-specific knowledge (e.g., oracle verifier and special
            tokens) required in certain approaches. Without a task, the special tokens are read from the
            tokenizer config in `llm.checkpoint_dir` when that directory is set, and are None otherwise.
        inference_args (dict | None): Arguments for inference. The optional "approach" entry (default
            "sampling") selects the inference approach; all other entries are forwarded to `make_inference`.
            If None, default arguments are used. The caller's dict is not modified.
        reflection_args (dict | None): Arguments for reflection. The optional "approach" entry (default None)
            selects the reflection class; all other entries are passed to its constructor. If None, default
            arguments are used. The caller's dict is not modified.
        max_steps (int | None): The maximum number of steps in the thought. Must be an int for the
            "markov-chain" and "mdp" implementations; it is not used by "tokenwise".
        **kwargs: Additional arguments for the reasoning implementation.

    Returns:
        GenerativeReasoner: An instance of the class obtained by passing the implementation's base reasoner
        class through the reflection object's `convert_reasoner`.

    Raises:
        NotImplementedError: If `impl` is not one of the implementations listed above.
        TypeError: If `impl` is "markov-chain" or "mdp" and `max_steps` is not an int.
    """
    # --- Validate first: nothing below is worth building for a call that cannot succeed. ---
    if impl == 'tokenwise':
        step_limited = False
    elif impl == 'markov-chain' or impl == 'mdp':
        step_limited = True
    else:
        raise NotImplementedError(f"\"{impl}\" reasoning is not implemented.")

    if step_limited and not isinstance(max_steps, int):
        raise TypeError(
            f"Reasoning implementation \"{impl}\" requires argument \"max_steps\" of type {int}, "
            f"but received value is {max_steps}."
        )

    from .reflection import get_reflection_class
    from .reasoners import TokenReasoner, MkvChainGenReasoner, MDPGenReasoner

    if isinstance(task, str):
        task = task_from_name(task)

    # Special tokens come from the task when there is one, otherwise from the checkpoint (if any).
    if task is not None:
        special_tokens = task.special_tokens()
    elif llm.checkpoint_dir is not None:
        special_tokens = _Vocabulary.from_tokenizer_config(llm.checkpoint_dir)
    else:
        special_tokens = None

    # Work on copies so that popping "approach" never mutates the caller's dicts.
    inference_args = {} if inference_args is None else inference_args.copy()
    inference_approach = inference_args.pop('approach', 'sampling')
    inference = _inference.make_inference(llm, inference_approach, **inference_args)

    reflection_args = {} if reflection_args is None else reflection_args.copy()
    reflection_approach: ReflectionApproach | None = reflection_args.pop('approach', None)
    reflection = get_reflection_class(reflection_approach)(**reflection_args)
    reflection.task = task

    if not step_limited:
        cls = reflection.convert_reasoner(TokenReasoner)
        return cls(inference, context_length,
                   special_tokens=special_tokens, **kwargs)

    # `impl` was validated above, so it is either "markov-chain" or "mdp" here.
    base_cls = MkvChainGenReasoner if impl == 'markov-chain' else MDPGenReasoner
    cls = reflection.convert_reasoner(base_cls)
    return cls(inference, context_length, max_steps,
               special_tokens=special_tokens, **kwargs)


class split_config(_t.TypedDict):
    """
    Per-split options that `collect_reflection_data` accepts in place of a bare split.
    """
    # The data split whose task instances are used.
    split: _Split
    # Forwarded unchanged to the data collector call; a non-positive value makes
    # `collect_reflection_data` skip the split entirely.
    # NOTE(review): the meaning of None is decided by the collector — confirm there.
    n_samples: int | None
    # Number of instances drawn at random from the split; None uses every instance.
    # A non-positive value makes `collect_reflection_data` skip the split entirely.
    n_instances: int | None


def collect_reflection_data(
    task: ReasoningTask | str,
    checkpoint: _Path | str,
    impl: ThoughtImpl,
    splits: _t.Iterable[_Split | split_config],
    devices: int | str = 1,
    num_nodes: int = 1,
    precision: _m.Precision | None = None,
    append_file: bool = True,
    out_dict: dict[_Split, list[_ReflectiveSample]] | None = None,
    out_path: _Path | str | None = None,
    file_fmt: _t.Literal['json', 'jsonl'] = 'json',
    verbose: int = 1,
    *,
    context_length: int,
    batch_size: int,
    max_steps: int | None = None,
    **kwargs,
) -> dict[_Split, list[_ReflectiveSample]]:
    """
    Collect reflective samples for the given splits of a task using a checkpointed LLM.

    The LLM is loaded from `checkpoint`, wrapped in a reasoner built by `make_reasoner` (with default
    inference and reflection arguments), and handed to a `ReflDataCollector` that is run once per split.

    Args:
        task (ReasoningTask | str): The task, or a task name resolved with `task_from_name`.
        checkpoint (Path | str): The checkpoint passed to `make_llm`.
        impl (str): The reasoning implementation passed to `make_reasoner`.
        splits: The splits to process. Each item is either a split or a dict with the keys "split",
            "n_samples" and "n_instances" (the latter two may be omitted or None). A split whose
            "n_samples" or "n_instances" is non-positive is skipped.
        devices (int | str): Forwarded to `as_fabric`.
        num_nodes (int): Forwarded to `as_fabric`.
        precision: Forwarded to `as_fabric`.
        append_file (bool): Forwarded to `ReflDataCollector.save_data` as `append`.
        out_dict (dict | None): Result dict to update in place. The data already stored for a split is
            handed to the collector, and the entry is replaced by what the collector returns.
            If None, a new dict is created.
        out_path (Path | str | None): Directory in which each split is saved as `<split>.<file_fmt>`.
            It is created (including missing parents) if necessary. If None, nothing is written.
        file_fmt (str): The extension of the saved files.
        verbose (int): Forwarded to `ReflDataCollector`.
        context_length (int): Forwarded to `make_reasoner`.
        batch_size (int): Forwarded to the collector call.
        max_steps (int | None): Forwarded to `make_reasoner`.
        **kwargs: Additional arguments for `ReflDataCollector`.

    Returns:
        dict: `out_dict` (or the newly created dict), mapping each processed split to its collected data.

    Raises:
        FileExistsError: If `out_path` exists but is not a directory.
        ValueError: If the reasoner built for `impl` is not an `MTPReasoner`.
    """
    # Reject an unusable output location up front, before the model is loaded.
    if out_path is not None:
        out_path = _Path(out_path)
        if out_path.exists() and not out_path.is_dir():
            raise FileExistsError(f"{out_path} is not a directory.")

    from random import choices
    from core.model import make_llm, as_fabric
    from .reasoners import MTPReasoner
    from .reflection import ReflDataCollector

    if isinstance(task, str):
        task = task_from_name(task)

    llm = make_llm(checkpoint,
                   distribute=as_fabric(devices=devices,
                                        num_nodes=num_nodes,
                                        precision=precision))
    reasoner = make_reasoner(
        llm, impl, context_length,
        task=task,
        max_steps=max_steps,
    )
    if not isinstance(reasoner, MTPReasoner):
        raise ValueError("The reasoner must be a recursive instance.")

    if out_dict is None:
        out_dict = {}
    if out_path is not None:
        out_path.mkdir(parents=True, exist_ok=True)

    collector = ReflDataCollector(task, reasoner, **kwargs, verbose=verbose)
    for split in splits:
        if isinstance(split, dict):
            n_instances = split.get('n_instances')
            n_samples = split.get('n_samples')
            split = split['split']
        else:
            n_instances = n_samples = None

        # An explicit non-positive budget means "collect nothing for this split".
        if (
            (n_instances is not None and n_instances <= 0) or
            (n_samples is not None and n_samples <= 0)
        ):
            continue

        instances = task.get_instances(split)
        if n_instances is not None:
            # NOTE(review): random.choices samples WITH replacement, so the same instance can be
            # drawn more than once — confirm this is intended rather than a distinct subset.
            instances = choices(instances, k=n_instances)
        data = collector(batch_size, instances, n_samples, out_dict.get(split))
        out_dict[split] = data

        if out_path is not None:
            out_file = out_path / f"{split}.{file_fmt}"
            ReflDataCollector.save_data(out_file, data, append=append_file)

    return out_dict


class param:
    """
    Namespace of helpers that assemble argument dicts.

    The class is used as a container only; its members are nested classes and static methods.
    """

    # Imported inside the class body under a double-underscore name, so the name is mangled
    # by Python and is not exposed as a plainly named attribute of `param`.
    from .reflection import param as __reflection

    class reflection(__reflection):
        """
        Extends the `param` namespace of the reflection module with data-collection and
        self-verification argument builders.
        """

        # Alias of the collector's evaluation-arguments type.
        # NOTE: within this class body the name shadows the builtin `eval`; the `eval` and
        # `eval()` used in the signature of `collect` below refer to this alias.
        eval = _ReflectiveCollector.EvalArgs

        # Same fields as the module-level `split_config`.
        class data_split(_t.TypedDict):
            split: _Split
            n_samples: int | None
            n_instances: int | None

        @staticmethod
        def collect(
            context_length: int,
            batch_size: int,
            n_branches: int,
            max_steps: int | None = None,
            propose_temperature: float = 1,
            solve_temperature: float = 1,
            rollout_temperature: float = 1,
            evaluation: eval = eval(),
            bg_device: str | None = "cpu",
        ):
            """
            Bundle the data-collection settings into a dict.

            Returns `locals()`, i.e. a dict mapping every parameter name to the value it
            received. Do not introduce other local variables here: they would show up as
            extra keys in the result.

            NOTE: the default of `evaluation` is built once, when the class body is executed,
            so every call that relies on the default shares that one object.
            NOTE(review): presumably the result is unpacked as keyword arguments for
            `collect_reflection_data` / the data collector — confirm against the callers.
            """
            return locals()

        @staticmethod
        def self_verify(
            budget: int | None = None,
            context_length: int | None = None,
            reject_mode: _t.Literal["retry", "revise"] = "retry",
            max_retry: int | None = None,
            external_verifier: _t.Literal["vf", "oracle", "random"] | None = None,
            revise_temperature: float | None = None,
            reflect_temperature: float = 0,
            enable_statistics: bool = False,
            vf_path: _Path | str | None = None,
            vf_tolerance: float | None = None,
            random_reject_ratio: float | _Path | str = 0,  # used by the random reflector
        ):
            """
            Bundle the settings of the "self-verify" reflection approach into a dict.

            Returns `locals()`: every parameter name mapped to its value, plus the key
            "approach" set to "self-verify".

            NOTE(review): presumably the result is meant to be passed as `reflection_args`
            to `make_reasoner`, which pops the "approach" key — confirm against the callers.
            """
            # Deliberately a local variable: it becomes the "approach" entry of the result.
            approach: ReflectionApproach = "self-verify"
            return locals()
