from __future__ import annotations
import torch
from typing import Literal, Self, Never

from core.model import LLM
from .generative import MTPReasoner
from ..formulation import CoT
from core.tokenization import Vocabulary
from core.inference import Context, Inference, StopTokens
from core.utils.buf import TokenBuffer, Seqs, StrArray


class ThoughtSteps(Context[int, Never, Never]):
    """Context specialised as ``Context[int, Never, Never]`` for reasoning steps.

    ``__iter__`` addresses its events with the integers 0, 1, 2, ... and
    yields one segment per step.
    """

    # Not referenced anywhere in this file; presumably a sentinel key for the
    # final step -- TODO confirm against callers.
    Final: Literal[-1] = -1

    def __iter__(self):
        """Yield one segment per step, walking the step keys upwards from 0.

        Iteration ends at the first key that is not contained in ``self`` or
        for which ``happened`` is false everywhere.  A step followed by
        another known key is cut with ``eseg(step, step + 1,
        stop_default=self.lengths)``; the last known step is cut with
        ``eseg(step, None)``.
        """
        step: int = 0
        while step in self:
            never_happened = torch.all(~self.happened(step))
            if never_happened:
                # This step was recorded nowhere, so stop iterating.
                return
            following = step + 1
            if following in self:
                segment = self.eseg(step, following, stop_default=self.lengths)
            else:
                segment = self.eseg(step, None)
            yield segment
            step = following


class MkvChainGenReasoner[Event](
    MTPReasoner[
        Event | Literal['output', 'step', 'outcome'],
        Literal['terminated'],
        Literal['thought'],
        ThoughtSteps | None,
    ],
):
    """Markov-chain reasoner: each generated step becomes the next state (the next prompt).

    One round consists of ``_generate`` (produce an 'output' segment that ends
    at an end-of-step or end-of-outcome stop sequence) followed by
    ``_transit`` (rebuild the context from that output).  Per the layout
    comment at the end of ``_transit`` the context then reads
    ``[<thought> ... STEP]`` or ``[<thought> ... </thought> OUTCOME]``.

    Samples whose output contains an end-of-outcome sequence are flagged in
    ``terminated``.  When ``require_thought`` is set, the outputs of all
    rounds are additionally accumulated in a ``ThoughtSteps`` buffer.

    The ``Event`` type parameter lets users of the class add their own event
    names next to the three reserved ones ('output', 'step', 'outcome').
    """

    # Implementation identifier; not read in this file (presumably consumed by
    # the base class or a registry -- TODO confirm).
    _impl = "markov-chain"

    # Event names set by this class itself, and their union with the
    # user-supplied event type.
    type _ReservedEvent = Literal['output', 'step', 'outcome']
    type _Event = _ReservedEvent | Event

    def __init__(
        self,
        inference: Inference,
        context_length: int,
        max_steps: int,
        end_of_step: StopTokens | None = None,
        beg_of_thought: str | int | None = None,
        ellipsis_token: str | int | None = None,
        end_of_outcome: StopTokens | None = None,
        special_tokens: Vocabulary | None = None,
        require_thought: bool = False,
    ):
        """Configure the reasoner and resolve its special tokens.

        Args:
            inference: Forwarded unchanged to the base-class constructor.
            context_length: Forwarded unchanged to the base-class constructor.
            max_steps: Forwarded unchanged to the base-class constructor.
            end_of_step: Stop tokens that end a single step.  If still unset
                after the ``special_tokens`` fallback, no step stop sequence
                is used.
            beg_of_thought: Token appended to the context in ``_launch`` and
                when a context is re-initialised in ``_transit``.
            ellipsis_token: Token appended after ``beg_of_thought`` when a
                context is re-initialised in ``_transit``.
            end_of_outcome: Stop tokens whose presence in an output marks the
                sample as terminated.
            special_tokens: Vocabulary used to fill in any of the four token
                arguments above that were left as ``None``.
            require_thought: When true, ``_launch`` allocates a
                ``ThoughtSteps`` buffer that records every step.
        """

        super().__init__(inference, context_length, max_steps)

        self.require_thought = require_thought

        # Fill in the special tokens that were not given explicitly from
        # `special_tokens`, if provided.
        if special_tokens is not None:
            if beg_of_thought is None:
                beg_of_thought = special_tokens.beg_of_thought
            if end_of_outcome is None:
                # Falls back to `end_of_reasoning` when `end_of_outcome` is falsy.
                end_of_outcome = special_tokens.end_of_outcome or special_tokens.end_of_reasoning
            if ellipsis_token is None:
                ellipsis_token = special_tokens.ellipsis
            if end_of_step is None:  
                # Falls back to `end_of_action` when `end_of_step` is falsy.
                end_of_step = special_tokens.end_of_step or special_tokens.end_of_action

        # NOTE(review): `_epllipsis` looks like a misspelling of "ellipsis"; it
        # is spelled the same way where it is read in `_transit`.
        self._epllipsis = self._as_seq(ellipsis_token)
        self._beg_of_thought = self._as_seq(beg_of_thought)
        if end_of_step is None:
            self._end_of_step: list[torch.Tensor] = []
        else:
            self._end_of_step = self._get_stop_seqs(end_of_step)
        # NOTE(review): unlike `end_of_step`, `end_of_outcome` is passed on even
        # when it is None; what happens then depends on `_get_stop_seqs`, which
        # is not visible here -- confirm this is intended.
        self._end_of_outcome = self._get_stop_seqs(end_of_outcome)

    def _as_seq(self, token: str | int | None):
        """Convert a single optional token into one 1-D sequence.

        Returns ``None`` when ``token`` is ``None``.  Otherwise
        ``_get_stop_seqs`` must return exactly one sequence: the tuple
        unpacking raises ``ValueError`` for any other count.
        """
        if token is not None:
            seq, = self._get_stop_seqs(token)
            assert seq.ndim == 1
        else:
            seq = None
        return seq

    @property
    def terminated(self):
        """The 'terminated' entry of ``self.context.info``.

        Initialised in ``_launch`` via ``ctx.make_flag(False)``; presumably a
        per-sample boolean mask -- TODO confirm.
        """
        return self.context.info['terminated']
    
    @terminated.setter
    def terminated(self, value: torch.Tensor):
        self.context.info['terminated'] = value
    
    @property
    def thought(self) -> ThoughtSteps | None:
        """The 'thought' entry of the session cache, or ``None`` if absent.

        The entry is only stored by ``_launch`` when ``require_thought`` is set.
        """
        thought = self.session.cache.get('thought')
        assert thought is None or isinstance(thought, ThoughtSteps)
        return thought

    def _launch(self, input: TokenBuffer) -> None:
        """Start a session on ``input`` and set up this reasoner's state.

        After the base-class launch this resets the 'terminated' flag,
        optionally allocates the thought buffer, and appends the
        begin-of-thought sequence to the context.
        """
        super()._launch(input)

        ctx = self.context
        ctx.info['terminated'] = ctx.make_flag(False)

        if self.require_thought:
            # Buffer filled with the pad index, sized `max_length * max_step`
            # along the last axis, plus a zero tensor of the context's shape
            # (presumably the initial lengths -- TODO confirm).
            # NOTE(review): this reads `self.max_step`, while the constructor
            # argument is called `max_steps`; confirm the attribute name on the
            # base class.
            thought = ThoughtSteps(
                torch.full(
                    (*ctx.shape, ctx.max_length * self.max_step),
                    self.session.pad_index,
                    device=ctx.device,
                    dtype=torch.int32
                ),
                torch.zeros(ctx.shape, device=ctx.device, dtype=torch.int64),
            )
            self.session.cache['thought'] = thought

        # prompt: [INPUT + <beg-of-thought>]
        if self._beg_of_thought is not None:
            ctx.append_(self._beg_of_thought)

    def _set_event_here(self, e: _Event) -> None:
        """Set event ``e`` on the context for the non-terminated samples only."""
        self.context.set_event(e, ~self.terminated)

    def _infer(self, stop_seqs: Seqs):
        """Run the base-class inference after seeding its stopped flag.

        ``inference.stopped`` is set to ``terminated`` first, presumably so
        that already-terminated samples are not generated further.
        """
        self.inference.stopped = self.terminated
        return super()._infer(stop_seqs)
    
    def _generate(self) -> None:
        """Generate one 'output' segment for the non-terminated samples.

        Generation may stop on any end-of-outcome or end-of-step sequence.
        """
        self._set_event_here('output')
        self._infer(self._end_of_outcome + self._end_of_step)
        self._finish_context('output')
    
    def _transit(self) -> None:
        """Turn the freshly generated output into the next state.

        The statement order matters: the previous 'terminated' mask is
        captured before it is overwritten, and the output segment is taken
        before the context is reset and rewound.
        """
        ctx = self.context
        terminated = self.terminated
        # Samples that were still running when this round started.
        thinking = ~terminated
        output = ctx.eseg('output', None)
        # Flag samples whose output contains an end-of-outcome sequence;
        # `old=terminated` presumably keeps earlier terminations -- TODO confirm.
        terminated_next = output.contains(self._end_of_outcome, old=terminated)
        # Samples that are flagged now but were still running before.
        is_outcome = terminated_next & thinking
        self.terminated = terminated_next
            
        # update thought
        if (thought := self.thought) is not None:
            thought.set_event(self._n_step, thinking)
            thought.append_from_(output, cond=thinking)
        
        # Prepare for the next generation
        # `when('step') < 0` presumably means the 'step' event has not been set
        # yet for that sample -- TODO confirm.
        require_init = (thinking) & (ctx.when('step') < 0)
        if torch.any(require_init):
            # initialize context as [<thought> ...]
            self.session.empty_(require_init)
            if self._beg_of_thought is not None:
                ctx.append_(self._beg_of_thought, cond=require_init)
            if self._epllipsis is not None:
                ctx.append_(self._epllipsis, cond=require_init)
            ctx.set_event('step', mask=require_init)
        
        assert ctx.happened('step').all()
        to_transit = thinking & self.inference.stopped  # if not stopped, the context has overflown
        ctx.remove_event('output')
        # Presumably rewinds the selected samples to their 'step' event so the
        # new output replaces the previous step -- TODO confirm.
        self.session.traceback('step', to_transit)
        ctx.set_event('outcome', is_outcome)
        ctx.append_from_(output, cond=to_transit)  # [<thought> ... STEP | <thought> ... </thought> OUTCOME]

    def _extract(self) -> tuple[ThoughtSteps | None, TokenBuffer]:
        """Return the thought buffer (or ``None``) and the 'outcome' segment."""
        return self.thought, self.context.eseg('outcome', None)
    
    def detokenize(
        self,
        cot: CoT[TokenBuffer, ThoughtSteps | None, TokenBuffer],
        skip_special_tokens: bool = False,
    ) -> CoT[StrArray, StrArray | None, StrArray]:
        """Detokenize every part of ``cot`` with ``self.llm``.

        Args:
            cot: Chain of thought holding token buffers for input, thought
                and outcome; the thought may be ``None``.
            skip_special_tokens: Forwarded to each ``llm.detokenize`` call.

        Returns:
            A ``CoT`` with the detokenized input, thought and outcome; the
            thought stays ``None`` if it was ``None``.
        """
        llm = self.llm
        return CoT(
            llm.detokenize(cot.input, skip_special_tokens=skip_special_tokens),
            None if cot.thought is None else llm.detokenize(
                cot.thought,
                skip_special_tokens=skip_special_tokens
            ),
            llm.detokenize(cot.outcome, skip_special_tokens=skip_special_tokens)
        )
