from __future__ import annotations
import abc
import torch
from typing import Literal, Never


from core.model import LLM
from .markov_chain import ThoughtSteps, MkvChainGenReasoner
from core.tokenization import Vocabulary
from core.inference import Inference, StopTokens
from core.inference.utils import to_stop_seqs
from core.utils.buf import TokenBuffer, Seqs


class MDPGenReasoner(MkvChainGenReasoner[Never]):
    """Reasoner that follows an MDP-style formulation.

    Every output step produced by the policy also carries the predicted next
    state (i.e. the next prompt), so the policy has to predict that state in
    addition to the step itself. The transition function consequently only
    reads the state back out of the output step.
    """

    _impl = "mdp"

    def __init__(
        self,
        inference: Inference,
        context_length: int,
        max_steps: int,
        end_of_state: StopTokens | None = None,
        end_of_step: StopTokens | None = None,
        beg_of_thought: str | int | None = None,
        ellipsis_token: str | int | None = None,
        end_of_outcome: StopTokens | None = None,
        special_tokens: Vocabulary | None = None,
        require_thought: bool = False,
    ):
        """Build the reasoner and resolve the end-of-state stop sequences.

        Every argument except ``end_of_state`` is forwarded, unchanged and in
        the same positional order, to the parent constructor.

        Args:
            end_of_state: Stop tokens marking the end of a predicted state.
                When omitted and ``special_tokens`` is provided,
                ``special_tokens.end_of_state`` is used instead.

        Raises:
            ValueError: If the resolved end-of-state stop sequences are empty.
        """
        super().__init__(
            inference,
            context_length,
            max_steps,
            end_of_step,
            beg_of_thought,
            ellipsis_token,
            end_of_outcome,
            special_tokens,
            require_thought,
        )

        # Fall back to the vocabulary's end-of-state marker when the caller
        # did not pass one explicitly.
        if end_of_state is None and special_tokens is not None:
            end_of_state = special_tokens.end_of_state

        self._end_of_state = self._get_stop_seqs(end_of_state)
        if not len(self._end_of_state):
            raise ValueError("End-of-state is an empty list.")

    def _generate(self) -> None:
        """Generate one output step.

        Marks the 'output' event, calls ``_infer`` with the end-of-state and
        end-of-outcome stop sequences combined, then finishes the 'output'
        context.
        """
        self._set_event_here('output')
        stop_seqs = self._end_of_state + self._end_of_outcome
        self._infer(stop_seqs)
        self._finish_context('output')

    def _transit(self):
        """Move to the next state using the state embedded in the output.

        The next state is the context segment starting at the last
        end-of-step match found since the 'output' event (or at the output
        position itself when there is no match). Along the way this updates
        ``self.terminated``, records the output into ``self.thought`` when a
        thought buffer exists, and rewrites the context for the entries that
        are still running.
        """
        context = self.context

        # Where the freshly generated output begins; defaults to the current
        # lengths when no 'output' event is recorded.
        output_start = context.when('output', default=context.lengths)
        next_start = context.find(
            self._end_of_step,
            which='last',
            since=output_start,
            default=output_start,
        )

        was_terminated = self.terminated
        active = ~was_terminated
        now_terminated = context.contains(
            self._end_of_outcome, since=output_start, old=was_terminated
        )
        # Entries that were still running and produced an outcome this step.
        reached_outcome = now_terminated & active
        # Take the next-state segment before the context is modified below.
        next_state = context.seg(start=next_start)
        self.terminated = now_terminated

        # Record this step's output in the thought buffer, if there is one.
        thought = self.thought
        if thought is not None:
            output_seg = context.seg(start=output_start)
            thought.set_event(self._n_step, active)
            thought.append_from_(output_seg, cond=active)

        # Running entries whose 'step' event is negative still need their
        # context initialised before the transition.
        needs_init = (active) & (context.when('step') < 0)
        if torch.any(needs_init):
            # Reset the context to [<thought> ...].
            self.session.empty_(needs_init)
            if self._beg_of_thought is not None:
                context.append_(self._beg_of_thought, cond=needs_init)
            # NOTE(review): `_epllipsis` (sic) is not assigned in this class;
            # presumably set by the parent from `ellipsis_token` - confirm.
            if self._epllipsis is not None:
                context.append_(self._epllipsis, cond=needs_init)
            context.set_event('step', mask=needs_init)

        context.remove_event('output')
        # Only running entries whose inference has stopped are transitioned.
        transit_mask = active & self.inference.stopped
        self.session.traceback('step', transit_mask)
        context.set_event('outcome', reached_outcome)
        context.append_from_(next_state, cond=transit_mask)
