"""
This module contains the inference classes to manage token-based inference processes.
"""


from .base import Inference, Context, Session
from .sampling import Sampling
from .search import BeamSearch, TokenVFBeamSearch, SeqVfBeamSearch
from .utils import StopTokens


import typing as _t
import core.model as _m
from pathlib import Path as _Path


# Names accepted by ``make_inference`` for its ``approach`` argument. Each
# value is also the ``"approach"`` entry emitted by the ``param`` helper of
# the same name (with dashes written as underscores).
Approach = _t.Literal["sampling", "beam-lh", "beam-vf-token", "beam-vf-seq"]


def make_inference(llm: _m.LLM, /, approach: Approach, **kwargs):
    """
    Create an inference object for the requested approach.

    Args:
        llm: The LLM instance to use for inference (positional-only).
        approach: The inference approach to use. Supported approaches are:
            - "sampling": Sampling-based inference.
            - "beam-lh" (deprecated): Beam search based on likelihood.
            - "beam-vf-token" (deprecated): Beam search based on a token-wise value function (verifier).
            - "beam-vf-seq" (deprecated): Beam search based on the value function of the entire sequence.
        **kwargs: Forwarded unchanged to the constructor of the selected class.
            Build them with the matching ``param`` helper (dashes in the
            approach name become underscores), e.g.
            ``make_inference(llm, **param.beam_lh(k=4))``; the helper's
            ``"approach"`` entry binds to the ``approach`` argument.

    Returns:
        A ``Sampling``, ``BeamSearch``, ``TokenVFBeamSearch`` or
        ``SeqVfBeamSearch`` instance, depending on ``approach``.

    Raises:
        NotImplementedError: If ``approach`` is not one of the supported names.
    """
    if approach == "sampling":
        return Sampling(llm, **kwargs)
    elif approach == "beam-lh":
        return BeamSearch(llm, **kwargs)
    elif approach == "beam-vf-token":
        return TokenVFBeamSearch(llm, **kwargs)
    elif approach == "beam-vf-seq":
        return SeqVfBeamSearch(llm, **kwargs)

    # Previously the list was glued to "are" with no separator
    # ("...supported aresampling,beam-lh,..."); separate it properly.
    supported = ", ".join(map(str, _t.get_args(Approach)))
    msg = (
        f"\"{approach}\" is not a supported inference approach. "
        f"The supported approaches are: {supported}."
    )
    raise NotImplementedError(msg)


class param:
    
    @staticmethod
    def sampling(temperature: float = 1, top_k: int | None = None, top_p: float = 1):
        approach: Approach = "sampling"
        return locals()

    @staticmethod
    def beam_lh(k: int, n_choices: int | None = None, submit_best: bool = True):
        approach: Approach = "beam-lh"
        return locals()
    
    @staticmethod
    def beam_vf_token(
        k: int,
        n_choices: int | None = None,
        submit_best: bool = True,
        vf: _Path | str | None = None,
        vf_batch_size: int = 1,
        weight_vf: float = 1,
        weight_pr: float = 0
    ):
        approach: Approach = "beam-vf-token"
        return locals()

    @staticmethod
    def beam_vf_seq(
        k: int, n_choices: int,
        submit_best: bool = True,
        vf: _Path | str | None = None,
        weight_vf: float = 1., weight_pr: float = 0.,
        propose_size: int = 1,
        temperature: float = 1,
        top_k: int | None = None,
        top_p: float = 1
    ):
        approach: Approach = "beam-vf-seq"
        return locals()
