from __future__ import annotations

import torch
from torch import Tensor
import abc
from dataclasses import dataclass, fields, Field, field
import numpy as np
import math
from typing import Sequence, Callable, Self, Any, Literal, Iterable, ClassVar, overload
from numpy.typing import NDArray
from litgpt.api import Preprocessor

from .th import TensorDataClass, TensorCollection, iterate_tensors
from . import th as _th
from . import iterate as _iter


# A NumPy array whose elements are strings.
type StrArray = NDArray[np.str_]
# Input forms accepted by `TokenBuffer.from_sequences`: a token tensor, a single
# string, a list of strings or token tensors, or a NumPy array of strings.
type SequencesLike = Tensor | str | list[str] | list[Tensor] | StrArray
# A sequence of token tensors; used for the suffix / stop-sequence arguments
# (`check_suffix`, `TokenBuffer.contains`, `count`, `find`).
type Seqs = Sequence[Tensor]


def check_suffix(
    tokens: Tensor,
    suffixes: Seqs,
    pos: int | None = None,
    old: Tensor | None = None,
    mask: Tensor | None = None,
):
    """
    Check, for every batch element, whether a stop sequence ends right before `pos`.

    A row is flagged when ``tokens[..., pos - len(s):pos] == s`` for at least one
    suffix ``s`` in `suffixes`.

    Args:
        tokens: token ids, shape ``[*batch, length]``.
        suffixes: candidate suffixes. Each is a 1-D sequence of token ids (a tensor
            or anything accepted by ``torch.as_tensor``). Each is moved to
            ``tokens.device`` before comparison. Empty suffixes are ignored; they
            would otherwise match every row.
        pos: exclusive end of the window that is tested. Defaults to
            ``tokens.size(-1)``. Suffixes longer than `pos` cannot match.
        old: optional boolean flags of shape ``[*batch]``. The result is OR-ed
            into a copy of them, and `old` itself is not modified.
        mask: optional boolean tensor of shape ``[*batch]``. Only rows where it is
            True can be newly flagged.

    Returns:
        A boolean tensor of shape ``[*batch]``.

    Raises:
        ValueError: if `pos` is larger than ``tokens.size(-1)``.
    """
    length = tokens.size(-1)
    if pos is None:
        pos = length
    elif pos > length:
        # Slicing past the end would silently shorten the window and then either
        # broadcast to a wrong answer or fail with an obscure shape error.
        raise ValueError(f"pos ({pos}) exceeds the sequence length ({length}).")

    if old is None:
        stopped = torch.zeros(tokens.shape[:-1], dtype=torch.bool, device=tokens.device)
    else:
        stopped = old.clone()

    for suffix in suffixes:
        suffix = torch.as_tensor(suffix, device=tokens.device)
        len_seq = len(suffix)
        # An empty window compares as "all equal", so an empty suffix would
        # flag every row; a suffix longer than the window can never match.
        if len_seq == 0 or pos < len_seq:
            continue
        matched = (tokens[..., pos - len_seq:pos] == suffix).all(dim=-1)
        if mask is None:
            stopped |= matched
        else:
            stopped |= matched & mask

    return stopped


class _BufferMeta(abc.ABCMeta):

    __sequential__: Iterable[str] = frozenset()

    def __init__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any], /, **kwds):
        super().__init__(name, bases, attrs, **kwds)

        sequential = frozenset(cls.__sequential__)
        for base in bases:
            if isinstance(base, _BufferMeta):
                assert isinstance(base.__sequential__, frozenset)
                sequential = sequential | base.__sequential__
        
        cls.__sequential__ = sequential
  

@dataclass
class TokenBuffer(TensorDataClass, metaclass=_BufferMeta):
    """
    A batch of variable-length token sequences stored in one padded tensor.

    ``tokens`` has shape ``[*batch, max_length]`` and ``lengths`` has shape
    ``[*batch]``. For every batch index ``i`` only ``tokens[i][:lengths[i]]``
    holds valid tokens, and the remaining columns are padding.

    Fields named in ``__sequential__`` (merged along the class hierarchy by
    ``_BufferMeta``) have a sequence axis of size ``max_length`` right after the
    batch axes. ``extend``, ``truncate`` and ``apply_sequential_transform`` only
    touch those fields.
    """

    __sequential__ = ['tokens']

    tokens: Tensor
    "Tensor (dtype=int): [*batch, max_length]"

    lengths: Tensor
    "Tensor (dtype=long): [*batch]"

    @property
    def device(self):
        """Device of ``tokens``."""
        return self.tokens.device

    @property
    def shape(self):
        """Batch shape, i.e. ``tokens.shape`` without the sequence axis."""
        return self.tokens.shape[:-1]

    @property
    def max_length(self):
        """Capacity of the sequence axis, i.e. ``tokens.size(-1)``."""
        return self.tokens.size(-1)

    @property
    def ndim(self):
        """Number of batch dimensions."""
        return self.tokens.ndim - 1

    @property
    def nelem(self):
        """Number of sequences in the batch (product of ``shape``)."""
        return math.prod(self.shape)

    @property
    def pos_idx(self) -> Tensor:
        """
        ``arange(max_length)`` (int64) on ``self.device``, cached on the instance.

        The cached tensor is rebuilt whenever it no longer matches ``tokens``,
        for example after ``tokens`` was reassigned to a tensor with another
        length or on another device. A stale cache would otherwise produce
        masks of the wrong size.
        """
        self.__pos_idx: Tensor
        try:
            cached = self.__pos_idx
        except AttributeError:
            cached = None
        if (
            cached is None
            or cached.numel() != self.max_length
            or cached.device != self.device
        ):
            cached = self._arange()
            self.__pos_idx = cached
        return cached

    def _arange(self, start: int | None = None, stop: int | None = None, step: int | None = None):
        """``torch.arange`` of int64 positions on ``self.device``; `stop` defaults to ``max_length``."""
        if start is None:
            start = 0
        if stop is None:
            stop = self.max_length
        if step is None or step == 1:
            return torch.arange(start, stop, dtype=torch.int64, device=self.device)
        else:
            return torch.arange(start, stop, step, dtype=torch.int64, device=self.device)

    def __post_init__(self):
        assert self._check_shape(), "inconsistant shape."

    def iterate_sequences(self):
        """Yield the tensor-valued fields listed in ``__sequential__``."""
        for name in self.__sequential__:
            val = getattr(self, name, None)
            if isinstance(val, Tensor):
                yield val

    def _check_shape(
        self,
        external_tensors: Iterable[Tensor | TensorCollection] | None = None,
        external_sequences: Iterable[Tensor] | None = None,
    ):
        """
        Return True when all tensors agree with the batch layout of ``tokens``.

        Two conditions are checked:

        - Every tensor of ``self`` (plus `external_tensors`, if given) must start
          with the batch shape.
        - Every sequential tensor of ``self`` (plus `external_sequences`, if
          given) must have size ``max_length`` on the axis right after the batch
          axes.
        """
        ndim = self.ndim
        shape = self.shape
        maxlen = self.max_length

        if external_tensors is None:
            tensors = self.iterate_tensors()
        else:
            tensors = iterate_tensors((self, *external_tensors))

        # Only the *sequential* fields of ``self`` have a sequence axis. Walking
        # all tensors of ``self`` here would also visit ``lengths``, which has no
        # axis ``ndim``, and ``size(ndim)`` would raise.
        sequences = [*self.iterate_sequences()]
        if external_sequences is not None:
            sequences.extend(external_sequences)

        return (
            all(tensor.shape[:ndim] == shape for tensor in tensors)
            and all(tensor.size(ndim) == maxlen for tensor in sequences)
        )

    def _mask(
        self,
        start: Tensor | int = 0,
        stop: Tensor | int | None = None,
        clampped: bool = True,
        _range: range | slice[int | None, int | None, int | None] | None = None,
    ):
        """
        Generate a mask indicating whether each token in `self.tokens[..., _range]` (i.e. position `t` in `_range`) has an absolute position `t` in `[start, stop)`. By default, `_range` is `[0, max_length)`.

        `start` / `stop` may be ints or ``[*batch]`` tensors. `stop` defaults to
        ``lengths``. An explicit `stop` is clamped to ``lengths`` when
        `clampped` is True. The result broadcasts to ``[*batch, len(_range)]``.
        """

        if _range is None:
            pos = self.pos_idx
        else:
            pos = self._arange(_range.start, _range.stop, _range.step)

        if stop is None:
            stop = self.lengths
        elif clampped:
            stop = self.lengths.clamp_max(stop)
        if isinstance(stop, Tensor):
            stop = stop.unsqueeze(-1)

        if isinstance(start, int) and start <= 0:
            return pos < stop
        else:
            if isinstance(start, Tensor):
                start = start.unsqueeze(-1)
            return (pos >= start) & (pos < stop)

    def view(self, *shape: int) -> Self:
        """Return a buffer whose batch dimensions are viewed as `shape`."""
        d = self.ndim
        return self.apply_transform(lambda x: x.view(*shape, *x.shape[d:]))

    def reshape(self, *shape: int) -> Self:
        """Return a buffer whose batch dimensions are reshaped to `shape`."""
        d = self.ndim
        return self.apply_transform(lambda x: x.reshape(*shape, *x.shape[d:]))

    def expand(self, *shape: int) -> Self:
        """Return a buffer whose batch dimensions are expanded (broadcast) to `shape`."""
        d = self.ndim
        return self.apply_transform(lambda x: x.expand(*shape, *x.shape[d:]))

    def repeat(self, *shape: int) -> Self:
        """Tile the batch dimensions ``shape[i]`` times (``Tensor.repeat`` semantics)."""
        d = self.ndim

        if len(shape) < d:
            shape = (1,) * (d - len(shape)) + shape

        def fn(x: Tensor):
            ones = (1,) * (x.ndim - d)
            return x.repeat(shape + ones)

        return self.apply_transform(fn)

    def apply_sequential_transform(self, fn: Callable) -> Self:
        """Return a new buffer with `fn` applied to the ``__sequential__`` fields only."""
        init_args = {}
        for name in self._init_fields():
            val = getattr(self, name)
            if name in self.__sequential__:
                if isinstance(val, Tensor):
                    val = fn(val)
                elif isinstance(val, TensorCollection):
                    val = val.apply_transform(fn)
            init_args[name] = val
        return self.__class__(**init_args)

    def extend(self, length: int, pad_index: int = 0) -> Self:
        """
        Return a buffer whose sequence capacity is larger by `length`.

        The new columns of every sequential field are filled with `pad_index`.

        Raises:
            ValueError: if `length` is negative.
        """
        if length < 0:
            raise ValueError(f"length must be non-negative, got {length}.")
        d = self.ndim
        shape = self.shape

        def fn(x: Tensor):
            elemshape = x.shape[d + 1:]
            tail = torch.full((*shape, length, *elemshape), pad_index, dtype=x.dtype, device=x.device)
            return torch.cat((x, tail), dim=d)

        return self.apply_sequential_transform(fn)

    def gather(self, dim: int, idx: Tensor) -> Self:
        """
        Gather batch entries along batch dimension `dim` (like ``torch.gather``).

        `dim` indexes the *batch* dimensions, so negative values count back from
        the last batch dimension, never from the sequence axis. If `idx` has
        fewer dimensions than the batch, it is expanded with the leading batch
        dimensions. Each field is gathered with `idx` broadcast over its
        trailing dimensions.

        Raises:
            IndexError: if `idx` has more dimensions than the batch, or `dim` is
                not a valid batch dimension.
        """
        ndim = self.ndim
        if idx.ndim > ndim:
            raise IndexError(f"index has {idx.ndim} dimensions but the batch has only {ndim}.")
        if not -ndim <= dim < ndim:
            raise IndexError(f"dim {dim} is out of range for batch shape {tuple(self.shape)}.")
        if dim < 0:
            # Normalize against the batch rank so every field gathers along the
            # same axis (``-1`` on ``tokens`` would otherwise be the sequence axis).
            dim += ndim
        if idx.ndim < ndim:
            idx = idx.expand(self.shape[:ndim - idx.ndim] + idx.shape)
        assert idx.ndim == ndim

        def gather_tensor(x: Tensor):
            elem_shape = x.shape[ndim:]
            idx_ = idx.reshape(idx.shape + (1,) * len(elem_shape)).expand(
                idx.shape + elem_shape
            )
            return torch.gather(x, dim, idx_)

        return self.apply_transform(gather_tensor)

    def tokens_at(self, *index: int, start: int = 0, stop: int | None = None, clamped: bool = True):
        """
        Return the tokens ``[start:stop]`` of the sequence at batch `index`.

        `stop` defaults to that sequence's length. An explicit `stop` is clamped
        to the length when `clamped` is True.
        """
        if len(index) != len(self.shape):
            raise IndexError(f"the batch shape is {self.shape}, but the index provides {len(index)} axis.")

        if stop is None:
            stop = int(self.lengths[index])
        elif clamped:
            length = int(self.lengths[index])
            stop = min(length, stop)
        return self.tokens[index][start: stop]

    def truncate(self) -> Self:
        """Drop trailing columns that are padding in every row (keep ``max(lengths)`` columns)."""
        # ``max()`` of an empty tensor raises, so an empty batch keeps 0 columns.
        new_length = int(self.lengths.max()) if self.lengths.numel() > 0 else 0
        idx = (slice(None),) * self.ndim + (slice(new_length),)
        return self.apply_sequential_transform(lambda x: x[idx])

    def seg(
        self,
        start: Tensor | int = 0,
        stop: Tensor | int | None = None,
        min_length: int = 0,
        clampped: bool = True,
        pad_index: int = 0,
    ):
        """
        Copy the window ``[start, stop)`` of every row into a new ``TokenBuffer``.

        `stop` defaults to ``lengths`` and is clamped to it when `clampped` is
        True. Rows with ``stop <= start`` become empty. The new buffer has
        capacity ``max(longest segment, min_length)`` and is padded with
        `pad_index`.
        """
        batch_shape = self.shape
        device = self.device

        if stop is None:
            stop = self.lengths
        elif clampped:
            stop = self.lengths.clamp_max(stop)

        new_lengths = stop - start
        if isinstance(new_lengths, int):
            new_lengths = max(0, new_lengths)
            new_lengths = torch.full_like(self.lengths, new_lengths)
        else:
            new_lengths = new_lengths.clamp_min(0)

        # ``max()`` of an empty tensor raises; an empty batch needs no columns.
        longest = int(new_lengths.max()) if new_lengths.numel() > 0 else 0
        length = max(longest, min_length)
        out = TokenBuffer(
            torch.full((*batch_shape, length), pad_index, dtype=self.tokens.dtype, device=device),
            new_lengths,
        )
        dst = out._mask()
        src = self._mask(start, stop, clampped=False)
        out.tokens[dst] = self.tokens[src]
        return out

    def select(self, mask_or_index) -> Self:
        """Index every field's batch dimensions with `mask_or_index`."""
        return self.apply_transform(lambda x: x[mask_or_index])

    def append_(self, tokens: Tensor | int, cond: Tensor | None = None):
        """
        Append tokens in place after the current end of every row.

        A scalar (int or 0-d tensor) appends one token. A 1-D tensor appends the
        same sequence to every row. With `cond` (boolean ``[*batch]``), only rows
        where it is True are updated; in that case ``lengths`` is replaced by a
        new tensor rather than modified in place. Tokens that would land at or
        beyond ``max_length`` are dropped, and ``lengths`` is clamped to
        ``max_length``.

        Raises:
            ValueError: if `tokens` has more than one dimension.
        """
        t = self.lengths
        # add only one token
        if isinstance(tokens, int) or tokens.ndim == 0:
            mask = self.pos_idx == t.unsqueeze(-1)  # [*shape, maxlen]
            if cond is not None:
                mask &= cond.unsqueeze(-1)
            self.tokens.masked_fill_(mask, tokens)
            if cond is None:
                self.lengths.add_(1).clamp_max_(self.max_length)
            else:
                self.lengths = torch.where(
                    cond, self.lengths + 1, self.lengths
                ).clamp_max_(self.max_length)
        # add a sequence
        else:
            if tokens.ndim > 1:
                raise ValueError("Cannot append multi-dimensional tensors.")
            len_ = tokens.size(-1)
            for i in range(len_):
                mask = self.pos_idx == (t + i).unsqueeze(-1)
                if cond is not None:
                    mask &= cond.unsqueeze(-1)
                self.tokens.masked_fill_(mask, tokens[i])
            if cond is None:
                self.lengths.add_(len_).clamp_max_(self.max_length)
            else:
                self.lengths = torch.where(
                    cond, self.lengths + len_, self.lengths
                ).clamp_max_(self.max_length)

    def append_from_(
        self,
        other: TokenBuffer,
        start: int | Tensor = 0,
        stop: int | Tensor | None = None,
        clammped: bool = True,
        cond: Tensor | None = None,
    ):
        """
        Append ``other.tokens[..., start:stop]`` to every row in place.

        `stop` defaults to ``other.lengths`` and is clamped to it when
        `clammped` is True. Rows where `cond` is False receive nothing. `other`
        is broadcast to this buffer's batch shape. Whatever does not fit within
        ``max_length`` is dropped. ``lengths`` is updated in place and never
        decreases.
        """

        if stop is None:
            stop = other.lengths
        elif clammped:
            stop = other.lengths.clamp_max(stop)
        # Number of tokens to copy per row. A window with ``start >= stop``, or a
        # row disabled by ``cond``, copies nothing. Without the clamp, a negative
        # count would shrink ``self.lengths`` below its current value.
        srclen = torch.as_tensor(stop - start, dtype=torch.int64, device=self.device).clamp_min(0)
        if cond is not None:
            srclen = torch.where(cond, srclen, torch.zeros_like(srclen))
        if other.shape != self.shape:
            other = other.expand(*self.shape)

        next_len = (self.lengths + srclen).clamp_max(self.max_length)
        append_len = next_len - self.lengths
        dest = self._mask(self.lengths, next_len, clampped=False)
        src = other._mask(start, start + append_len, clampped=False)
        self.tokens[dest] = other.tokens[src]
        self.lengths.copy_(next_len)

    @staticmethod
    def empty(shape: Sequence[int], length: int, device=None, pad_index: int = 0):
        """Create a buffer of all-empty rows (int32 tokens filled with `pad_index`)."""
        return TokenBuffer(
            torch.full((*shape, length), pad_index, device=device, dtype=torch.int32),
            torch.zeros(shape, device=device, dtype=torch.int64),
        )

    def __get_start_pos(self, since: int | Tensor | None, default: int = 0):
        """Return ``(since, start)``, where `start` is the smallest start position as an int."""
        if since is None:
            since = default
        if isinstance(since, int):
            start = since
        else:
            start = int(since.min())
        assert start >= 0
        return since, start

    def __get_stop_pos(self, until: int | Tensor | None, clamped: bool):
        """
        Return ``(until, stop)``, where `stop` is the largest stop position as an int.

        `until` defaults to ``lengths`` and is clamped to it when `clamped` is
        True. `stop` never exceeds ``max_length``, because no token exists past it.
        """
        if until is None:
            until = self.lengths
        elif clamped:
            until = self.lengths.clamp_max(until)
        if isinstance(until, int):
            stop = until
        else:
            stop = int(until.max()) if until.numel() > 0 else 0
        return until, min(stop, self.max_length)

    def _search_window(
        self,
        since: int | Tensor | None,
        until: int | Tensor | None,
        clamped: bool,
    ) -> tuple[int, Tensor, Tensor]:
        """
        Shared preamble of `contains`, `count` and `find`.

        Returns ``(start, tokens, mask)``:

        - ``tokens = self.tokens[..., start:stop]`` is the smallest column window
          covering every searched position.
        - ``mask`` (``[*batch, stop - start]``) marks the positions ``t`` with
          ``since <= t < until``.
        """
        since, start = self.__get_start_pos(since)
        until, stop = self.__get_stop_pos(until, clamped)
        # An empty window instead of an invalid ``arange(start, stop)`` when
        # ``start > stop``.
        stop = max(stop, start)
        _range = slice(start, stop)
        tokens = self.tokens[..., _range]
        # ``until`` already reflects ``clamped``; clamping it again inside
        # ``_mask`` would make ``clamped=False`` ineffective.
        mask = self._mask(since, until, clampped=False, _range=_range)
        return start, tokens, mask

    def contains(
        self,
        seqs: Seqs | Tensor,
        since: int | Tensor | None = None,
        until: int | Tensor | None = None,
        old: Tensor | None = None,
        clamped: bool = True,
        mode: Literal["any", "all", "respective"] = "any",
    ) -> Tensor:
        """
        Check whether `seqs` end at some position ``t`` with ``since <= t < until``.

        Matching is delegated to ``_th.match_suffix`` on the window
        ``tokens[..., min(since):max(until)]``. Occurrences that begin before
        the window presumably cannot be detected (TODO confirm against
        ``match_suffix``).

        Args:
            seqs: a single 1-D token tensor, or a sequence of them.
            since, until: per-row (or scalar) bounds. `until` defaults to
                ``lengths`` and is clamped to it when `clamped` is True.
            old: optional flags OR-ed into the result.
            mode: for several `seqs`: ``"any"`` / ``"all"`` reduce to
                ``[*batch]``; ``"respective"`` keeps ``[*batch, n_seqs]``.

        Raises:
            IndexError: if `seqs` is a tensor that is not 1-D.
            ValueError: if `mode` is unknown.
        """
        if isinstance(seqs, Tensor) and seqs.ndim != 1:
            raise IndexError("Only 1-dimensional tensor is allowed")

        _, tokens, mask = self._search_window(since, until, clamped)  # mask: (*shape, len_range)
        match = _th.match_suffix(tokens, seqs)  # (*shape, len_range, n_seqs?)
        if isinstance(seqs, Tensor):
            match = (match & mask).any(dim=-1)  # (*shape)
        else:
            match = (match & mask.unsqueeze(-1)).any(dim=-2)  # (*shape, n_seqs)
            if mode == 'respective':
                pass
            elif mode == 'all':
                match = match.all(dim=-1)
            elif mode == 'any':
                match = match.any(dim=-1)
            else:
                raise ValueError(f"Unknown mode: {mode!r}")

        if old is not None:
            return old | match
        else:
            return match

    def count(
        self,
        seqs: Seqs | Tensor,
        since: int | Tensor | None = None,
        until: int | Tensor | None = None,
        old: Tensor | None = None,
        clamped: bool = True,
        mode: Literal["any", "respective"] = "any",
    ) -> Tensor:
        """
        Count occurrences of `seqs` ending at positions ``t`` with ``since <= t < until``.

        With several `seqs`, ``"any"`` sums the counts of all of them into
        ``[*batch]``. Two sequences ending at the same position count twice.
        ``"respective"`` returns ``[*batch, n_seqs]``. If `old` is given, it is
        added to the result.

        Raises:
            IndexError: if `seqs` is a tensor that is not 1-D.
            ValueError: if `mode` is unknown.
        """
        if isinstance(seqs, Tensor) and seqs.ndim != 1:
            raise IndexError("Only 1-dimensional tensor is allowed")

        _, tokens, mask = self._search_window(since, until, clamped)  # mask: (*shape, len_range)
        match = _th.match_suffix(tokens, seqs)  # (*shape, len_range, n_seqs?)
        if isinstance(seqs, Tensor):
            cnt = (match & mask).count_nonzero(dim=-1)  # (*shape)
        else:
            match = match & mask.unsqueeze(-1)  # (*shape, len_range, n_seqs)
            if mode == 'respective':
                cnt = match.count_nonzero(dim=-2)  # (*shape, n_seqs)
            elif mode == 'any':
                cnt = match.count_nonzero(dim=(-1, -2))  # (*shape)
            else:
                raise ValueError(f"Unknown mode: {mode!r}")
        if old is not None:
            return old + cnt
        else:
            return cnt

    def find(
        self,
        seqs: Tensor | Sequence[Tensor],
        since: int | Tensor | None = None,
        until: int | Tensor | None = None,
        clamped: bool = True,
        which: Literal['first', 'last'] = 'first',
        mode: Literal["any", "respective"] = "any",
        default: int | Tensor = -1,
    ) -> Tensor:
        """
        Find suffixes `(*m, L)` in tokens (*n, T).
        Output a tensor `t` with shape `(*n, *m)`, such that `t := t[*i, *j]` makes sure that
        `tokens[*i, :t]` ends with `suffix[*j]`.
        If `suffix[*j]` does not exist in `tokens[*i]`, `t[*i, *j]` is the default value.

        Only suffixes ending at positions in ``[since, until)`` are considered.
        `which` selects the first or last such occurrence.

        Raises:
            IndexError: if `seqs` is a tensor that is not 1-D.
            ValueError: if `mode` is unknown.
        """
        if isinstance(seqs, Tensor) and seqs.ndim != 1:
            raise IndexError("Only 1-dimensional tensor is allowed")

        start, tokens, mask = self._search_window(since, until, clamped)  # mask: (*shape, len_range)
        match = _th.match_suffix(tokens, seqs)  # (*shape, len_range, n_seqs?)
        # A match in window column p ends at absolute token start + p, so the
        # prefix that ends with it is tokens[..., :start + p + 1].
        bias = start + 1
        if isinstance(seqs, Tensor):
            m = match & mask  # (*shape, len_range)
            idx = _th.find_nonzero(m, -1, which, default, bias)
        else:
            m = match & mask.unsqueeze(-1)  # (*shape, len_range, n_seqs)
            if mode == 'any':
                m = m.any(dim=-1)  # (*shape, len_range)
                idx = _th.find_nonzero(m, -1, which, default, bias)
            elif mode == 'respective':
                idx = _th.find_nonzero(m, -2, which, default, bias)
            else:
                raise ValueError(f"Unknown mode: {mode!r}")
        return idx

    @staticmethod
    def concat(
        *batches: TokenBuffer,
        min_length: int = 0,
        truncate_length: int | None = None,
        pad_index: int = 0,
    ):
        """
        Concatenate the valid tokens of `batches` row by row into a new int32 buffer.

        The capacity is ``max(longest concatenation, min_length)``, further
        limited to `truncate_length` when it is given (``0`` included). Tokens
        beyond the capacity are dropped.

        Raises:
            ValueError: if no batch is given.
        """

        if len(batches) == 0:
            raise ValueError("The argument is empty.")

        batch_shape = batches[0].shape
        device = batches[0].device
        lengths = torch.stack([b.lengths for b in batches])
        required_length = lengths.sum(0)
        # ``max()`` of an empty tensor raises; an empty batch needs no columns.
        longest = int(required_length.max()) if required_length.numel() > 0 else 0
        length = max(longest, min_length)
        if truncate_length is not None:
            length = min(truncate_length, length)

        out = TokenBuffer(
            torch.full((*batch_shape, length), pad_index, dtype=torch.int32, device=device),
            # Pass the shape as one tuple so a 0-dim batch shape also works.
            torch.zeros(batch_shape, dtype=torch.int64, device=device),
        )
        for batch in batches:
            t = out.lengths
            if torch.all(t >= length):
                break
            next_t = (t + batch.lengths).clamp_max(length)
            dest = out._mask(start=t, stop=next_t, clampped=False)
            src = batch._mask(stop=(next_t - t))
            out.tokens[dest] = batch.tokens[src]
            out.lengths.copy_(next_t)

        return out

    def enumerate(
        self,
        start: int | Tensor = 0,
        stop: int | Tensor | None = None,
        clamped: bool = True
    ):
        """
        Yield ``(index, tokens_at(*index, start=..., stop=...))`` for every batch index.

        Tensor bounds are broadcast to the batch shape and read per row.
        """
        if isinstance(start, Tensor):
            start = start.expand(self.shape)
        if isinstance(stop, Tensor):
            stop = stop.expand(self.shape)
        for i in _iter.indices(self.shape):
            start_ = start if isinstance(start, int) else int(start[i])
            stop_ = stop if isinstance(stop, int | None) else int(stop[i])
            yield i, self.tokens_at(*i, start=start_, stop=stop_, clamped=clamped)

    def detokenize(
        self,
        detokenizer: Callable[[Tensor], str],
        start: int | Tensor = 0,
        stop: int | Tensor | None = None,
    ) -> StrArray:
        """Apply `detokenizer` to every row's ``[start, stop)`` tokens; return a str array of shape ``self.shape``."""
        array = np.zeros(self.shape, dtype=object)
        for idx, tokens in self.enumerate(start, stop):
            array[idx] = detokenizer(tokens)
        return array.astype(str)

    @staticmethod
    def from_sequences(
        tokens: SequencesLike,
        length: Tensor | int | None = None,
        preprocessor: Preprocessor | None = None,
        pad_index: int = 0,
    ):
        """
        Build a buffer from strings or token tensors.

        Strings are encoded with `preprocessor`. For a str / list / ndarray
        input, the lengths come from the encoded sequences and `length` is
        ignored. Lists and arrays are right-padded with `pad_index`, and an
        ndarray's shape becomes the batch shape. A tensor input is cloned and
        uses `length` (scalar or broadcastable), or its full width when `length`
        is None. Tokens are cast to int32 and moved to ``preprocessor.device``
        when a preprocessor is given.

        Raises:
            ValueError: if a string is given without `preprocessor`, an element
                is not 1-D, or the input is a scalar tensor.
        """

        def _tokenize(seq: Tensor | str):
            if isinstance(seq, str):
                if preprocessor is None:
                    raise ValueError("preprocessor is missing.")
                return preprocessor.encode(seq)
            if seq.ndim != 1:
                raise ValueError("Each element of the input must be a sequence of tokens (ndim = 1).")
            return seq

        # tokenize the input strings
        if isinstance(tokens, str):
            if preprocessor is None:
                raise ValueError("missing preprocessor.")
            tokens = preprocessor.encode(tokens)
            length = torch.tensor(tokens.size(0), dtype=torch.int64, device=tokens.device)
        elif isinstance(tokens, (list, np.ndarray)):
            if isinstance(tokens, np.ndarray):
                input_token_list = [_tokenize(seq) for seq in tokens.flatten().tolist()]
                input_shape = tokens.shape
            else:
                input_token_list = list(map(_tokenize, tokens))
                input_shape = None
            tokens = pad_and_stack_tensors(input_token_list, pad_index=pad_index)
            length = torch.tensor(
                list(map(lambda seq: seq.size(0), input_token_list)),
                dtype=torch.int64,
                device=tokens.device
            )
            if input_shape is not None:
                tokens = tokens.reshape(*input_shape, tokens.size(-1))
                length = length.reshape(*input_shape)
            del input_token_list
        else:
            tokens = tokens.clone()

        device = None if preprocessor is None else preprocessor.device
        tokens = tokens.to(device=device, dtype=torch.int32)

        assert isinstance(tokens, Tensor)
        if tokens.ndim == 0:
            raise ValueError("the input must be a sequence or sequences.")

        if length is None:
            length = tokens.size(-1)
        if isinstance(length, Tensor):
            length = length.to(device=tokens.device, dtype=torch.int64)
        else:
            length = torch.tensor(length, dtype=torch.int64, device=tokens.device)
        length = length.expand(tokens.shape[:-1]).clone()

        return TokenBuffer(tokens, length)

    def make_tensor(self, elem_shape: tuple[int, ...], dtype: torch.dtype, value: Any = 0):
        """Create a ``[*batch, *elem_shape]`` tensor filled with `value` on ``self.device``."""
        return torch.full(self.shape + elem_shape, value, dtype=dtype, device=self.device)

    def make_sequence(self, elem_shape: tuple[int, ...], dtype: torch.dtype, value: Any = 0):
        """Create a ``[*batch, max_length, *elem_shape]`` tensor filled with `value`."""
        return torch.full(self.tokens.shape + elem_shape, value, dtype=dtype, device=self.device)

    def make_flag(self, value: bool | torch.Tensor = False):
        """Create a fresh boolean ``[*batch]`` tensor from a bool or a broadcastable tensor."""
        if isinstance(value, bool):
            return torch.full(self.shape, value, device=self.device, dtype=torch.bool)
        else:
            return value.expand(self.shape).bool().clone()


def pad_and_stack_tensors(tensors: list[Tensor], length: int | None = None, pad_index=0):
    
    max_length = max(tensor.size(0) for tensor in tensors)
    if length is not None and max_length > length:
        raise ValueError(f"One of the inputs ({max_length} tokens) is longer than the desired length ({length}).")
    length = length or max_length
    device = tensors[0].device
    dtype = tensors[0].dtype
    padded_tensor = torch.full((len(tensors), length), pad_index, device=device, dtype=dtype)
    for i, tensor in enumerate(tensors):
        padded_tensor[i, :tensor.size(0)] = tensor.to(device=device, dtype=dtype)
    return padded_tensor
