from __future__ import annotations

import abc
import torch
import dataclasses as dc
from typing import Literal, Final, Any, ClassVar, Callable, overload, cast, Protocol
from functools import partial

from ..formulation import CoT, Reasoner
from core.inference import Inference, Session, Context
from core.inference.utils import StopTokens, to_stop_seqs
from core.utils.buf import TokenBuffer, StrArray, Seqs
from core.utils.th import NestedTensorDict, TensorCollection, ObjectTensor


# Closed set of names selecting how a thought process is implemented.
# Each name has a matching record in `_IMPL_INFO`, and a `GenerativeReasoner`
# subclass declares the one it uses through its `_impl` class variable.
ThoughtImpl = Literal['tokenwise', 'stepwise', 'markov-chain', 'mdp']
"""How the thought process is implemented.
- `tokenwise`: append the next generated token to the previous state.
- `stepwise`: append the generated step to the previous state.
- `markov-chain`: the generated step is exactly the next state.
- `mdp`: The generated step contains an action sequence and exactly the next state.
"""


@dc.dataclass
class _ImplInfo:
    """Static properties of one `ThoughtImpl` variant.

    Instances are collected in the module-level `_IMPL_INFO` registry and
    returned by `GenerativeReasoner.impl_info`.
    """

    # Name of the thought implementation this record describes.
    impl: ThoughtImpl
    # NOTE(review): presumably whether one reasoning run spans several
    # contexts (prompt -> output) rather than a single growing one; it is
    # True for 'markov-chain' and 'mdp' only -- confirm against its consumers.
    multi_contextual: bool


# Registry of the static properties of every `ThoughtImpl` variant, keyed by
# the variant's name. Every record is built with keyword arguments so the
# entries read uniformly.
_IMPL_INFO: dict[ThoughtImpl, _ImplInfo] = {
    info.impl: info
    for info in (
        _ImplInfo(impl='tokenwise', multi_contextual=False),
        _ImplInfo(impl='stepwise', multi_contextual=False),
        _ImplInfo(impl='markov-chain', multi_contextual=True),
        _ImplInfo(impl='mdp', multi_contextual=True),
    )
}


class Callbacks[KCache]:

    _host: GenerativeReasoner | None = None

    @property
    def host(self) -> GenerativeReasoner:
        if self._host is None:
            raise AttributeError("The host has not been assigned.")
        return self._host

    @property
    def cache(self) -> NestedTensorDict[KCache, torch.Tensor | TensorCollection]:
        return self.host.session.cache
    
    def get_local_state(self) -> Any:
        return None
    
    def load_local_state(self, state: Any) -> None: ...

    def attach(self, host: GenerativeReasoner):
        if self._host is not None:
            raise AttributeError("The hook has already been attached. Please detach first.")
        assert self not in host._hooks
        host._hooks.append(self)
        self._host = host
    
    def detach(self):
        if self._host is not None:
            self._host._hooks.remove(self)
        self._host = None

    def __call__(self, _callback: str, /, *args, **kwargs):
        """
        the default callback, used when `callback` is not explicitly defined as a function.
        """
        ...

    def __getattr__(self, _callback: str) -> Callable:
        return partial(self, _callback=_callback)


class GenHook(Protocol):
    """Structural type of the callbacks a `GenerativeReasoner` invokes on its hooks."""

    def before_reasoning(self, input: TokenBuffer):
        """Invoked by `GenerativeReasoner.before_reasoning` once the inference is launched."""
        ...

    def after_context(self, prompt_lengths: torch.Tensor):
        """Invoked by `GenerativeReasoner._finish_context` when a context has finished."""
        ...

    def after_reasoning(self, thought, outcome: TokenBuffer):
        """Invoked by `GenerativeReasoner.after_reasoning` with the solved thought and outcome."""
        ...

    def on_close(self):
        """Invoked by `GenerativeReasoner.close`."""
        ...


# Signature of `GenerativeReasoner.ref`: maps a tuple of ints to a dict of
# reference information (the default installed by the reasoner returns `{}`).
# NOTE(review): the ints are presumably sample/batch indices -- confirm with callers.
type _RefFn = Callable[[tuple[int, ...]], dict[str, Any]]


class GenerativeReasoner[KEvent, KInfo, KCache, Thought](
    Reasoner[
        TokenBuffer,  # Input
        Thought,  # Thought
        TokenBuffer,  # Outcome
    ],
    abc.ABC,
):
    """
    TypeVars: `[KEvent, KInfo, KCache, Thought]`
    """

    _impl: ClassVar[ThoughtImpl]

    type SessionType = Session[KEvent, KInfo, Literal['logits'], KCache]
    type ContextType = Context[KEvent, KInfo, Literal['logits']]

    session: SessionType
    """
    The decoding session for text generation.
    """
    
    ref: _RefFn
    """The function to obtain reference information for supplementary functions,
    such as data collection, reward computing, evaluation, etc."""

    def __init__(self, inference: Inference, context_length: int):
        
        self.llm: Final = inference._llm
        self.session: Final = Session(context_length)
        self.inference: Final[Inference] = inference
        self._hooks: list[Callbacks[KCache]] = []
        self._eos: list[torch.Tensor] = to_stop_seqs(self.llm.preprocessor, None)
        self.preprocess: Final = inference.preprocess
        self.ref = lambda _: {}
        
    def _get_stop_seqs(self, s: StopTokens | None):
        return to_stop_seqs(self.llm.preprocessor, s)
    
    @overload
    def _as_seq(self, token: str | int) -> torch.Tensor: ...
    @overload
    def _as_seq(self, token: None) -> None: ...
    # implementation
    def _as_seq(self, token: str | int | None):
        if token is not None:
            seq, = self._get_stop_seqs(token)
            assert seq.ndim == 1
        else:
            seq = None
        return seq

    @property
    @abc.abstractmethod
    def terminated(self) -> torch.Tensor:
        raise NotImplementedError
    
    @terminated.setter
    def terminated(self, value: torch.Tensor) -> None:
        raise NotImplementedError
    
    @property
    def impl_info(self) -> _ImplInfo:
        return _IMPL_INFO[self._impl]

    @property
    def context(self) -> ContextType:
        return self.session.context
    
    def __call__(self, input: TokenBuffer) -> tuple[Thought, TokenBuffer]:
        self.before_reasoning(input)
        thought, outcome = self._solve()
        self.after_reasoning(thought, outcome)
        return thought, outcome
    
    def _infer(self, stop_seqs: Seqs):
        if stop_seqs:
            stop_seqs = self._eos + list(stop_seqs)
        else:
            stop_seqs = self._eos
        return self.inference.infer_sequence(stop_seqs)
    
    def before_reasoning(self, input: TokenBuffer):
        """
        Call `before_reasoning` hooks, and launch the inference.
        """
        
        self._launch(input)
        for hook in self._hooks:
            cast(GenHook, hook).before_reasoning(input)
    
    def after_reasoning(self, thought: Thought, outcome: TokenBuffer):
        """
        Call `after_reasoning` hooks, make necessary cleaning, and make ready for the next inputs.
        """
        for hook in self._hooks:
            cast(GenHook, hook).after_reasoning(thought, outcome)
        self.inference.release()
        self.session.release()

    @abc.abstractmethod
    def _solve(self) -> tuple[Thought, TokenBuffer]:
        """
        The main reasoning process. This would involve one or multiple `inference` stages.
        """
        ...

    def _launch(self, input: TokenBuffer) -> None:
        """
        1. Connect the inference with the session.
        2. Launch the inference.
        """
        self.inference.connect(self.session)
        assert self.inference.session is self.session
        self.inference.launch(input)
        
    def _finish_context(self, output: KEvent | torch.Tensor):
        """
        Signals hooks (e.g., collector and evaluator) that a context `(prompt -> output)` has finished
        by calling `after_context`.
        """

        if isinstance(output, torch.Tensor):
            when = output
        else:
            when = self.context.when(output)
        for hook in self._hooks:
            cast(GenHook, hook).after_context(when)

    def close(self):
        """
        Release all resources and call `on_close` hooks.
        """
        for hook in self._hooks:
            cast(GenHook, hook).on_close()
    
    def detokenize(
        self,
        cot: CoT[TokenBuffer, Thought, TokenBuffer],
        skip_special_tokens: bool = True,
    ) -> CoT[StrArray, Any, StrArray]:
        raise NotImplementedError

    def _get_local_state(self) -> dict[str, Any]:
        return {"__hooks__": [h.get_local_state() for h in self._hooks]}
    
    def _load_local_state(self, **state_vars) -> None:
        for hook, hook_state in zip(self._hooks, state_vars["__hooks__"]):
            hook.load_local_state(hook_state)


class MTPReasoner[KEvent, KInfo, KCache, Thought](GenerativeReasoner[KEvent, KInfo, KCache, Thought]):
    """A **recursive** implementation of reasoning based on **Markov thought process**."""

    _n_step: int
    """The number of recursive steps."""

    max_step: int
    """The maximal number of recursive steps."""

    def __init__(self, inference: Inference, context_length: int, max_step: int):
        super().__init__(inference, context_length)
        self.max_step = max_step
    
    def _solve(self) -> tuple[Thought, TokenBuffer]:
        self._generate()  # thought state (prompt) ---[generate]---> thought step
        self._transit()   # Applying the transition function: update the context as the next prompt
        self._n_step += 1
        if self._done():
            self.inference.submit()  # post-processing of inference algorithm
            thought, outcome = self._extract()
            return thought, outcome
        else:
            return self._solve()
    
    def _launch(self, input: TokenBuffer) -> None:
        super()._launch(input)
        self._n_step = 0

    @abc.abstractmethod
    def _generate(self) -> None: ...

    @abc.abstractmethod
    def _transit(self) -> None: ...
    
    def _done(self) -> bool:
        return self._n_step >= self.max_step or bool(torch.all(self.terminated))

    @abc.abstractmethod
    def _extract(self) -> tuple[Thought, TokenBuffer]: ...

    def _get_local_state(self):
        state = super()._get_local_state()
        state['n_step'] = self._n_step
        return state

    def _load_local_state(self, **state_vars) -> None:
        self._n_step = state_vars.pop('n_step')
        super()._load_local_state(**state_vars)
