from __future__ import annotations

import abc
import torch
import dataclasses as dc
from typing import Callable, Any, Sequence, cast, Final, Literal, Never, Iterable
from core.model import LLM
from .utils import StopTokens, to_stop_seqs
from core.utils.buf import TokenBuffer, Seqs, check_suffix, SequencesLike
from core.model import LLM, Preprocessor, GPT
from core.utils import kv
from core.utils.th import TensorDict, NestedTensorList, NestedTensorDict, TensorCollection, TensorDataClass
from contextlib import contextmanager


# Alias of `None`; not referenced in the visible part of this file.
# NOTE(review): presumably annotates the result of in-place operations — confirm.
type InPlaceDecision = None
# A function mapping a tensor to a new tensor (see `Session.apply_transform`).
type _TensorTransform = Callable[[torch.Tensor], torch.Tensor]
# A function applied to a tensor whose return value is not constrained
# (see `Session.apply_operation`).
type _TensorOperation = Callable[[torch.Tensor], Any]


@dc.dataclass
class Context[KEvent, KInfo, KData](TokenBuffer):
    """A context is an augmented version of buffer, which stores not only the sequences of tokens,
    but also related events and data.
    """

    __sequential__ = {'data'}

    events: TensorDict[KEvent] = dc.field(default_factory=TensorDict)
    """Tensors with position (index) of specific events."""

    info: TensorDict[KInfo] = dc.field(default_factory=TensorDict)
    """Tensors that have shape `(*shape, *elem_shape)`, such as flags and conditions."""
    
    data: TensorDict[KData] = dc.field(default_factory=TensorDict)
    """Tensors that have shape `(*shape, max_length, *elem_shape)`"""

    def __post_init__(self):
        super().__post_init__()

    def eseg(
        self, 
        start: KEvent | None,
        stop: KEvent | None,
        min_length: int = 0,
        start_default: torch.Tensor | int = -1,
        stop_default: torch.Tensor | int = -1, 
        pad_index: int = 0,
    ):
        
        start_pos = 0 if start is None else self.when(start, default=start_default)
        stop_pos = self.lengths if stop is None else self.when(stop, default=stop_default)

        if isinstance(start_pos, torch.Tensor):
            start_pos = torch.where(start_pos < 0, self.lengths, start_pos)
        if isinstance(stop_pos, torch.Tensor):
            stop_pos = stop_pos.masked_fill(stop_pos < 0, 0)

        return self.seg(start_pos, stop_pos, min_length, pad_index=pad_index)
    
    def when(self, event: KEvent, default: int | torch.Tensor = -1):
        try:
            t = self.events[event]
        except KeyError:
            if isinstance(default, int):
                return torch.full_like(self.lengths, default)
            else:
                return default.expand(self.shape)
        if isinstance(default, int) and default == -1:
            return t
        else:
            return torch.where(t < 0, default, t)
        
    def happened(self, event: KEvent) -> torch.Tensor:
        return self.when(event) >= 0
    
    def set_event(self, event: KEvent, 
                  mask: torch.Tensor | None = None,
                  when: torch.Tensor | int | None = None):
        if when is None:
            when = self.lengths
        elif isinstance(when, int):
            when = torch.full_like(self.lengths, when)
        elif when.shape != self.shape:
            when = when.expand(self.shape)
        if mask is None:
            self.events[event] = when.clone()
        else:
            old_when = self.when(event)
            new_when = torch.where(mask, when, old_when)
            self.events[event] = new_when

    def remove_event(self, event: KEvent, mask: torch.Tensor | None = None):
        events = self.events
        if event not in events:
            return
        if mask is None:
            del events[event]
        else:
            events[event] = events[event].masked_fill(mask, -1)

    def __contains__(self, event: KEvent):
        return event in self.events

    @staticmethod
    def empty(shape: Sequence[int], length: int, device=None, pad_index: int = 0):
        return Context[KEvent, KInfo, KData](
            torch.full((*shape, length), pad_index, device=device, dtype=torch.int32),
            torch.zeros(shape, device=device, dtype=torch.int64),
        )


# Size of a KV-cache; not referenced in the visible part of this file.
type _KVSize = tuple[int, int]
"""(max_batch_size, max_seq_len)"""


@dc.dataclass
class _SessionState[KEvent, KInfo, KData, KCache](TensorDataClass):
    """Snapshot of a `Session`, built by `Session._get_all_state` and consumed
    by `Session._restore_state`."""

    # Write position of the session when the snapshot was taken.
    pos_write: int
    # The session cache (moved to the state device if one was requested).
    cache: NestedTensorDict[KCache, torch.Tensor | TensorCollection]
    # The session context (moved to the state device if one was requested).
    context: Context[KEvent, KInfo, KData]
    # Read position of every managed model, or the literal 0 when model
    # positions were not captured (all positions are reset to 0 on restore).
    model_pos: list[int] | Literal[0]
    # Cloned KV tensors of every managed model, or None when not captured.
    model_kvs: NestedTensorList[kv.KVTensors] | None
    # Device the context lived on before being moved to the state device;
    # None when the snapshot was not moved.
    _prev_device: torch.device | None = None


class Session[KEvent, KInfo, KData, KCache]:
    """
    Manage the states of a decoding / inference / reasoning process. All state tensors will
    have consistent batch shapes: (*batch_shape, ...).

    TypeVars: `[KEvent, KInfo, KData, KCache]`

    It maintains:
    1. A context of tokens
    2. Kv-caches of models
    3. Runtime information of models on the context.
    4. State tensors of hooks, inference algorithms, and reasoning.
    """

    __context: Context[KEvent, KInfo, KData]
    __pos_write: int

    @dc.dataclass
    class ModelRuntime[Model: GPT]:

        model: Final[Model]
        session: Final[Session]
        device: Final[torch.device]
        outkey: Final[KData | None] = None
        """The key of data tensor that stores the output of this model."""
        lazy_kv_growth: bool = True
        max_batch_size: int = dc.field(init=False)
        max_seqlen: int = dc.field(init=False)
        pos: int = dc.field(init=False)

        def kv_reset(self, batch_size: int, seqlen: int):
            batch_size, seqlen = kv.reset(
                self.model, batch_size, seqlen,
                device=self.device, lazy=self.lazy_kv_growth
            )
            self.max_batch_size = batch_size
            self.max_seqlen = seqlen
            self.pos = 0
        
        def predict(
            self,
            tokens: torch.Tensor, 
            token_pos: torch.Tensor | None = None,
            batch_mask: torch.Tensor | None = None,
        ) -> torch.Tensor:
            batch_shape = tokens.shape[:-1]
            length = tokens.size(-1)
            kv.mask_kv_caches(self.model, batch_mask)
            if batch_mask is None:
                tokens = tokens.reshape(-1, length)
                output = self.model.forward(tokens, token_pos)
                output = output.view(*batch_shape, length, *output.shape[2:])
            else:
                assert batch_mask.shape == batch_shape
                tokens = tokens[batch_mask]
                output = self.model.forward(tokens, token_pos)
            return output

        def __call__(self, start: int, end: int, batch_mask: torch.Tensor | None) -> torch.Tensor:
            session = self.session
            context = session.context
            kout = self.outkey
            
            assert self.pos >= start
            assert start < end <= context.max_length

            # tokens: (*shape, len)
            tokens = context.tokens[..., start:end]
            token_pos = context.pos_idx[start:end]
            output = self.predict(tokens, token_pos, batch_mask)
            self.pos = max(self.pos, end)

            if kout is not None and (out := context.data.get(kout)) is not None:
                _idx = tuple(slice(None) for _ in range(context.ndim)) + (slice(start, end),)
                if batch_mask is None:
                    out[_idx] = output
                else:
                    out[_idx][batch_mask] = output

            return output 

    def __init__(self, context_length: int, pad_index: int = 0):
        self.context_length: int = context_length
        self.pad_index: int = pad_index

        self.__models: list[Session.ModelRuntime[GPT]] = []
        self.__cache: NestedTensorDict[KCache, torch.Tensor | TensorCollection] = NestedTensorDict()

    @property
    def pos_write(self):
        return self.__pos_write
    
    @property
    def cache(self):
        return self.__cache
    
    @cache.setter
    def cache(self, value: NestedTensorDict[KCache, torch.Tensor | TensorCollection]):
        self.__cache = value
        if __debug__:
            self._check_shape()
    
    def models(self):
        for m in self.__models:
            yield m.model
    
    def _get_all_state(
        self,
        model_state: Literal[None, 'kv', 'pos'] = 'pos',
        device: torch.device | str | None = None
    ) -> _SessionState[KEvent, KInfo, KData, KCache]:
        
        if device is not None and not isinstance(device, torch.device):
            device = torch.device(device)
        
        cache = self.__cache
        ctx = self.__context
        model_pos = 0
        kvs = None
        prev_device = None
        
        if model_state == "pos" or model_state == "kv":
            model_pos = [runtime.pos for runtime in self.__models]
        if model_state == 'kv':
            kvs = NestedTensorList(
                kv.KVTensors.extract_from(runtime.model, ctx.shape, slice(runtime.pos),
                                          device, clone=True)
                for runtime in self.__models
            )
       
        # move the cache and context to backgraond device.
        if device is not None:
            if not isinstance(device, torch.device):
                device = torch.device(device)
            prev_device = self.context.device
            cache = cache.to(device=device)
            ctx = ctx.to(device=device)

        return _SessionState(self.__pos_write, cache, ctx, model_pos, kvs, prev_device)
    
    def _restore_state(self, s: _SessionState, update_cache=False):
        device = s._prev_device
        cache = s.cache if device is None else s.cache.to(device=device)
        context = s.context if device is None else s.context.to(device=device)

        if update_cache:
            cache = NestedTensorDict(cache | self.__cache)
        
        self.__cache = cache
        self.__context = context
        self.__pos_write = s.pos_write

        if isinstance(poslist := s.model_pos, list): 
            for runtime, pos in zip(self.__models, poslist):
                runtime.pos = pos
        else:
            assert s.model_pos == 0
            for runtime in self.__models:
                runtime.pos = 0
        if (kvs := s.model_kvs) is not None:
            for runtime, kv in zip(self.__models, kvs):
                runtime.pos = pos
                kv.insert_to(runtime.model)
        
        assert self._check_shape()
    
    def fork(
        self,
        cache: Iterable[KCache] | bool = False,
        data: Iterable[KData] | bool = False,
        info: Iterable[KInfo] | bool = False,
        tokens: bool = False,
        event: Iterable[KEvent] | bool = False,
        state_device: torch.device | str | None = None,
        restore_model: Literal[None, 'kv', 'pos'] = 'pos',
    ):
        """
        Fork the inference process. 
        The current states, such as cache, context, and events, will be reserved and returned.
        The new state (e.g., cache, data, ...) will be cloned from the current state,
        if the corresponding argument is `True` or the keys of cloned states.
        """

        reserved_state = self._get_all_state(model_state=restore_model, device=state_device)

        # handle arguments and make new state
        def get_keys[Key](src: Iterable[Key], keys: Iterable[Key] | bool) -> Iterable[Key]:
            if keys is True:
                return src
            elif keys is False:
                return []
            else:
                return keys
        
        # override states
        _cache = self.__cache
        _ctx = self.__context
        cache = get_keys(_cache, cache)
        data = get_keys(_ctx.data, data)
        info = get_keys(_ctx.info, info)
        event = get_keys(_ctx.events, event)
        self.__context = _ctx.__class__(
            _ctx.tokens.clone() if tokens else _ctx.tokens,
            _ctx.lengths.clone(),
            TensorDict({k: _ctx.events[k].clone() for k in event}),  # events
            TensorDict({k: _ctx.info[k].clone() for k in info}),  # info
            TensorDict({k: _ctx.data[k].clone() for k in data}),  # data
        )
        self.__cache = NestedTensorDict({k: _cache[k].clone() for k in cache})
        del _ctx, _cache

        return reserved_state

    @contextmanager
    def tryfork(
        self,
        cache: Iterable[KCache] | bool = False,
        data: Iterable[KData] | bool = False,
        info: Iterable[KInfo] | bool = False,
        tokens: bool = False,
        event: Iterable[KEvent] | bool = False,
        state_device: torch.device | str | None = None,
        restore_model: Literal[None, 'kv', 'pos'] = 'pos',
        update_cache: bool = False,
    ):
        """
        fork the inference process. After the managed code block, the stored states will be restored.
        """

        reserved_state = self.fork(cache, data, info, tokens, event, state_device, restore_model)
        try:
            yield reserved_state
        finally:
            self._restore_state(reserved_state, update_cache=update_cache)

    def manage_model[Model: GPT](
        self, model: Model, device: torch.device, kdata: KData | None = None
    ) -> ModelRuntime[Model]:
        """
        Subsume a model into management.
        """

        for runtime in self.__models:
            if runtime.model is model:
                assert runtime.device == device
                assert runtime.outkey == kdata
                return cast(Session.ModelRuntime[Model], runtime)
        
        runtime = Session.ModelRuntime(model, self, device, kdata)
        self.__models.append(runtime)
        return runtime

    def apply_transform(self, fn: _TensorTransform):
        """
        Apply a transform to all managed tensors, including the KV-cache of models,
        the context, and cache.
        """

        for runtime in self.__models:
            if runtime.pos > 0:
                pos = slice(runtime.pos)
                kv.kv_apply(runtime.model, fn, self.__context.shape, pos)
        
        self.__context = self.__context.apply_transform(fn)
        self.__cache = self.__cache.apply_transform(fn)

        if __debug__:
            self._check_shape()
    
    def _check_shape(self):
        return self.__context._check_shape([self.__cache])
    
    def apply_operation(self, fn: _TensorOperation):
        """
        Apply an operation (in-place function) to all managed tensors, including the KV-cache
        of models, the context, and cache.
        """

        if __debug__:
            self._check_shape()

        for runtime in self.__models:
            if runtime.pos > 0:
                pos = slice(runtime.pos)
                kv.kv_inplace_apply(runtime.model, fn, self.__context.shape, pos)
        
        self.__cache.apply_operation(fn)
        self.__context.apply_operation(fn)

    def traceback(
        self,
        event: KEvent,
        cond: torch.Tensor | None = None,
        default: int | torch.Tensor = 0,
        clear_tokens: bool = False,
    ):
        ctx = self.__context
        when = ctx.when(event, default)
        if cond is not None:
            when = torch.where(cond, when, ctx.lengths)

        if not torch.all((when >= 0) & (when <= ctx.lengths)):
            raise IndexError

        ctx.lengths[:] = when
        # Any later events are cancelled
        for e, when in ctx.events.items():
            if e == event:
                continue
            ctx.events[e] = when.where(when <= ctx.lengths, -1)
        
        if clear_tokens:
            ctx.tokens[~ctx._mask()] = self.pad_index
        
        if cond is None or cond.any():
            self.relocate_pos_write(cond)

    def empty_(self, cond: torch.Tensor | None = None):
        ctx = self.__context
        if cond is None:
            ctx.lengths.zero_()
            ctx.events.clear()
            ctx.tokens[:] = self.pad_index
            self.relocate_pos_write()
        elif cond.any():
            cond = cond.expand(ctx.shape)
            ctx.lengths.masked_fill_(cond, 0)
            for e in ctx.events.values():
                e.masked_fill_(cond, -1)
            ctx.tokens[cond] = self.pad_index
            self.relocate_pos_write(cond)

    @property
    def context(self):
        try:
            return self.__context
        except AttributeError:
            raise RuntimeWarning("The context is not accessible as the session has not been prepared.")

    def prepare(self, ctx: Context):
        try:
            current = self.__context
            is_current = ctx is current
        except AttributeError:
            is_current = False

        if not is_current:
            for runtime in self.__models:
                runtime.pos = 0
            self.__context = ctx

        self.relocate_pos_write()
        
        if __debug__:
            self._check_shape()

    def release(self):
        try:
            del self.__context, self.__pos_write
        except AttributeError:
            pass
        
        self.__cache.clear()
        self.__models.clear()

    def empty(
        self,
        shape: Sequence[int] | int = (),
        context_length: int | None = None,
        device: torch.device | None = None,
    ) -> Context:

        if isinstance(shape, int):
            shape = (shape,)
        context_length = context_length or self.context_length
        context = Context.empty(shape, context_length, device, self.pad_index)
        return context
    
    def relocate_pos_write(self, mask: torch.Tensor | None = None):
        lengths = self.__context.lengths
        if mask is not None:
            lengths = lengths[mask]
        if torch.numel(lengths) > 0:
            self.__pos_write = int(lengths.min())
        else:
            self.__pos_write = self.__context.max_length
        for runtime in self.__models:
            runtime.pos = min(runtime.pos, self.__pos_write)

    def _write(self,
               next_tokens: torch.Tensor,
               write_mask: torch.Tensor,
               stopped: torch.Tensor,
               stop_seqs: Seqs) -> torch.Tensor:
        
        pos = self.__pos_write
        ctx = self.__context

        if pos >= ctx.max_length:
            raise IndexError
        
        ctx.tokens[..., pos][write_mask] = next_tokens[write_mask]
        ctx.lengths[write_mask] = pos + 1
        stopped = check_suffix(ctx.tokens, stop_seqs, pos + 1, stopped, write_mask)
        self.relocate_pos_write(~stopped)
        return stopped
    
    def _write_mask(self, active_mask: torch.Tensor | None = None):
        m = self.__context.lengths == self.__pos_write
        if active_mask is not None:
            m = m & active_mask
        return m


class Inference[KInfo, KData](abc.ABC):
    """
    Specifying the inference process that generates the outputs from inputs (the prompt) of a context.

    TypeVars: `[KInfo, KData]`
    """

    type _KInfo = KInfo | Literal['stopped', 'logprob']
    type _KData = KData | Literal["logits", "probs"]
    type _Session = Session[Never, _KInfo, _KData, Never]
    type _Context = Context[Any, _KInfo, _KData]

    __session: _Session | None = None

    @property
    def session(self):
        """
        The session that the inference is connected to.
        If the inference has not been connected to a session, `AttributeError` is raised.
        """
            
        if self.__session is None:
            raise AttributeError("The inference has not been connected to a session.")
        return self.__session
    
    @property
    def context(self):
        """
        The context that the inference is processing, equivalent to `self.session.context`.
        Raises AttributeError for no connected session.
        """
        return self.session.context
    
    def __init__(self, llm: LLM):
        self._llm = llm

    @property
    def stopped(self):
        """
        Whether the inference has stopped
        """

        return self.context.info['stopped']
    
    @stopped.setter
    def stopped(self, value: torch.Tensor | bool):
        context = self.context
        if isinstance(value, bool):
            context.info['stopped'] = context.make_flag(value)
        else:
            if value.shape != context.shape:
                raise IndexError
            context.info['stopped'] = value
    
    @property
    def logprob(self):
        return self.context.info.get('logprob')
    
    @logprob.setter
    def logprob(self, value: torch.Tensor | float):
        context = self.context
        if isinstance(value, torch.Tensor):
            if value.shape != context.shape:
                raise IndexError
            context.info['logprob'] = value
        else:
            context.info['logprob'] = context.make_tensor((), torch.float32, 0)
    
    def connect(self, session: Session):
        """
        Connect the inference to a session, which will be aware of the state and models of
        the inference. In turn, the session provides the access to the runtime information
        of the models.
        """

        if self.__session is not None:
            raise RuntimeError("The inference has already been connected.")
        
        self.__session = cast(Inference._Session, session)
        self.llm = self.__session.manage_model(self._llm.model, self._llm.device, 'logits')
    
    def release(self):
        """
        Release all resources and disconnect from the session.
        """
        self.__session = None
        del self.llm
    
    @abc.abstractmethod
    def launch(self, input: TokenBuffer) -> None:
        """
        Start a brand new inference process.
        1. Initialize the context of inference.
        2. Initialize the local state of inference.
        3. Preparing the models for the inference, including resetting the KV-cache of models.
        """

        self._llm.eval()
    
    def submit(self):
        """
        Finalize the inference. For example, in some inference algorithms,
        this involve choosing from several candidates. By default, nothing happens.
        """
        
        pass
    
    @abc.abstractmethod
    def infer_token(self) -> torch.Tensor:
        """
        Determine the token-level behavior of the inference.
        It predicts the next tokens of the entire batch.
        """
        raise NotImplementedError

    def _predict_logits(self, active: torch.Tensor | None):
        pos_r = self.llm.pos
        pos_w = self.session.pos_write
        pos_r = min(pos_w - 1, pos_r)

        if pos_r < 0:
            raise IndexError("Cannot read empty inputs.")

        return self.llm(pos_r, pos_w, active)

    def _step(self, stop_seqs: Seqs) -> bool | None:
        session = self.session
        ctx = session.context
        stopped = self.stopped

        if session.pos_write >= ctx.max_length:
            return True
        if torch.all(stopped):
            return True
        
        next_tokens = self.infer_token()
        write_mask = session._write_mask(~stopped)
        stopped = session._write(next_tokens, write_mask, stopped, stop_seqs)
        self.stopped = stopped

    def infer_sequence(self, stop_seqs: Seqs):
        if torch.all(stopped := self.stopped):
            return
        else:
            self.session.relocate_pos_write(~stopped)
        while True:
            if (stopped := self._step(stop_seqs)):
                break
    
    def preprocess(
        self, input_tokens: SequencesLike,
        input_length: torch.Tensor | int | None = None,
    ):
        preprocessor = self._llm.preprocessor
        pad_index = 0 if self.__session is None else self.__session.pad_index
        return TokenBuffer.from_sequences(input_tokens, input_length, preprocessor, pad_index)
    
    @torch.inference_mode()
    def generate(
        self,
        input_tokens: SequencesLike,
        input_lengths: torch.Tensor | int | None = None,
        max_output_length: int | None = None, 
        context_length: int | None = None,
        scale: tuple[int, ...] | int | None = None,
        scale_mode: Literal['repeat', 'expand'] = 'expand',
        require_logits: bool = False,
        stop_tokens: StopTokens | None = None,
        ignore_input_stop: bool = True,
    ) -> Context[Literal['output'], Literal['stopped'], Literal['logits']]:
        
        input = self.preprocess(input_tokens, input_lengths)
        if scale is not None:
            if isinstance(scale, int):
                scale = scale,
            if scale_mode == 'repeat':
                input = input.repeat(*scale)
            if scale_mode == 'expand':
                input = input.reshape(*input.shape, *(1 for _ in scale)).expand(*input.shape, *scale)

        if context_length is None:
            if max_output_length is None:
                context_length = self.llm.model.max_seq_length
            else:
                context_length = int(input.lengths.max()) + max_output_length
        else:
            if max_output_length is not None:
                raise ValueError("Confliction: `max_output_length` and `context_length`")

        session = Session(context_length)
        self.connect(session)
        stop_seqs = to_stop_seqs(self._llm.preprocessor, stop_tokens)
        self.launch(input)
        ctx = cast(Context[Literal['output'], Literal['stopped'], Literal['logits']], self.context)
        if require_logits:
            ctx.data['logits'] = ctx.make_sequence((self._llm.vocab_size,), torch.float, 0)
        ctx.set_event('output')
        if not ignore_input_stop:
            self.stopped = ctx.contains(stop_seqs)
        else:
            self.stopped = False
        self.infer_sequence(stop_seqs)
        session.release()
        self.release()

        return ctx
