import torch
from typing import Literal, Any

from core.model import LLM
from .generative import MTPReasoner
from ..formulation import CoT
from core.inference import Inference, StopTokens
from core.tokenization import Vocabulary
from core.utils.buf import TokenBuffer, Seqs, StrArray


class TokenReasoner(
    MTPReasoner[
        Literal['output', 'thought', 'outcome'],  # Event
        Any,
        Any,
        TokenBuffer,  # Thought
    ],
):
    """Reasoner that generates thought and outcome in one token-level pass.

    The base reasoner is configured with ``max_step=1``. Generation runs
    until one of the end-of-outcome stop sequences is hit, and the thought
    is separated from the outcome afterwards (see ``_extract``).

    Helpers such as ``_as_seq``, ``_get_stop_seqs``, ``_infer`` and
    ``_finish_context`` are inherited and not defined in this class.
    """

    _impl = "tokenwise"

    def __init__(
        self,
        inference: Inference,
        context_length: int,
        beg_of_thought: str | int | None = None,
        end_of_thought: StopTokens | None = None,
        end_of_outcome: StopTokens | None = None,
        special_tokens: Vocabulary | None = None,
    ):
        """Set up the reasoner and resolve its delimiter tokens.

        Args:
            inference: Inference backend, forwarded to the base class.
            context_length: Context length, forwarded to the base class.
            beg_of_thought: Token appended to the context when a run is
                launched; nothing is appended if it resolves to ``None``.
            end_of_thought: Stop token(s) separating thought from outcome.
                If it resolves to ``None``, no separation is performed.
            end_of_outcome: Stop token(s) that end generation.
            special_tokens: Optional vocabulary used to fill in any of the
                three delimiters above that were left as ``None``.
        """
        super().__init__(inference, context_length, max_step=1)

        # Explicitly passed delimiters always win; the vocabulary is only
        # consulted for the ones that are still None.
        if special_tokens is not None:
            beg_of_thought = (
                special_tokens.beg_of_thought
                if beg_of_thought is None
                else beg_of_thought
            )
            end_of_thought = (
                special_tokens.end_of_thought
                if end_of_thought is None
                else end_of_thought
            )
            # The outcome delimiter falls back to end_of_reasoning when the
            # vocabulary's end_of_outcome is falsy.
            end_of_outcome = (
                (special_tokens.end_of_outcome or special_tokens.end_of_reasoning)
                if end_of_outcome is None
                else end_of_outcome
            )

        self._beg_of_thought = self._as_seq(beg_of_thought)
        # An empty list means "no end-of-thought delimiter configured".
        self._end_of_thought: list[torch.Tensor] = (
            []
            if end_of_thought is None
            else self._get_stop_seqs(end_of_thought)
        )
        # NOTE(review): unlike end_of_thought, this may still be None here;
        # confirm that _get_stop_seqs accepts None.
        self._end_of_outcome = self._get_stop_seqs(end_of_outcome)

    @property
    def terminated(self):
        """Termination state, read straight from ``inference.stopped``."""
        return self.inference.stopped

    @terminated.setter
    def terminated(self, value: torch.Tensor):
        """Write the termination state through to ``inference.stopped``."""
        self.inference.stopped = value

    def _launch(self, input: TokenBuffer) -> None:
        """Start a run: reset the stop state and mark the context events.

        The 'thought' event is set first, the begin-of-thought sequence is
        appended when one is configured, and then the 'output' event is set.
        """
        super()._launch(input)
        self.inference.stopped = False

        context = self.context
        context.set_event('thought')
        prefix = self._beg_of_thought
        if prefix is not None:
            context.append_(prefix)
        context.set_event('output')

    def _generate(self) -> None:
        """Run inference with the end-of-outcome stop sequences."""
        self._infer(self._end_of_outcome)

    def _transit(self):
        """Finish the context for the 'output' event."""
        self._finish_context('output')

    def _extract(self) -> tuple[TokenBuffer, TokenBuffer]:
        """Split the context into ``(thought, outcome)``.

        Without an end-of-thought delimiter, both elements are the very
        same segment running from the 'thought' event to the end. With one,
        the context is searched from the 'thought' event onwards (using
        ``ctx.lengths`` as the position when nothing is found); the thought
        runs from the 'thought' event up to that position and the outcome
        from that position to the end.
        """
        ctx = self.context

        if not self._end_of_thought:
            # No delimiter: thought and outcome alias one segment object.
            whole = ctx.eseg('thought', None)
            return whole, whole

        boundary = ctx.find(
            self._end_of_thought,
            since=ctx.when('thought'),
            default=ctx.lengths,
        )
        thought_part = ctx.seg(ctx.when('thought'), boundary)
        outcome_part = ctx.seg(boundary, None)
        return thought_part, outcome_part

    def detokenize(
        self,
        cot: CoT[TokenBuffer, TokenBuffer, TokenBuffer],
        skip_special_tokens: bool = True,
    ) -> CoT[StrArray, StrArray, StrArray]:
        """Convert a token-level chain of thought into its string form.

        Args:
            cot: Chain of thought whose input, thought and outcome are
                token buffers.
            skip_special_tokens: Forwarded unchanged to ``llm.detokenize``.

        Returns:
            A new ``CoT`` holding the detokenized input, thought and
            outcome, in that order.
        """
        model = self.llm

        def to_text(tokens):
            return model.detokenize(
                tokens, skip_special_tokens=skip_special_tokens
            )

        return CoT(
            to_text(cot.input),
            to_text(cot.thought),
            to_text(cot.outcome),
        )
