import torch
import itertools
import dataclasses as dc
from typing import Sequence, Any
from contextlib import contextmanager
from core.utils.th import TensorDataClass
from core.inference import Context


# Scalar Python types accepted as per-key default fill values (see the
# ``default_data`` / ``default_info`` fields of ``ContextStack``).
type _Number = int | bool | float


@dc.dataclass
class ContextStack[KEvent, KInfo, KData](TensorDataClass):

    type _Context = Context[KEvent, KInfo, KData]

    buf: _Context
    depths: torch.Tensor
    default_token: int | None = 0
    default_data: dict[KData, _Number] = dc.field(default_factory=dict)
    default_info: dict[KInfo, _Number] = dc.field(default_factory=dict)
    default_event: int | None = -1

    def __post_init__(self):
        assert self.depths.shape == self.buf.shape[:-1]

    @property
    def shape(self):
        return self.depths.shape
    
    @property
    def ndim(self):
        return self.depths.ndim
    
    @property
    def max_depth(self):
        return self.buf.shape[-1] - 1
    
    @classmethod
    def empty(
        cls,
        shape: Sequence[int],
        max_length: int,
        max_depth: int,
        device: torch.device | None = None,
        default_token: int | None = 0,
        default_data: dict[KData, _Number] = {},
        default_info: dict[KInfo, _Number] = {},
        default_event: int | None = -1,
    ):
        return cls(
            Context.empty((*shape, max_depth + 1), max_length,
                          device=device, pad_index=(default_token or 0)),
            torch.zeros(shape, dtype=torch.int64, device=device),
            default_token=default_token,
            default_data=default_data,
            default_info=default_info,
            default_event=default_event,
        )

    def __check_index(
        self,
        idx: torch.Tensor,
        mask: torch.Tensor | None,
    ):
        H = self.max_depth
        if mask is not None:
            idx = torch.where(mask, idx, H)
        if torch.any((idx >  H) | (idx < 0)):
            raise IndexError("The index is out of range")
        if idx.shape != self.shape:
            idx = idx.expand(self.shape)
        return idx
    
    def __clear_context(self, ctx: _Context, index: Any):

        default_token = self.default_token
        default_event = self.default_event
        default_info = self.default_info
        default_data = self.default_data
    
        ctx.lengths[index] = 0
        if default_token is not None:
            ctx.tokens[index] = default_token
        if default_event is not None:
            to_remove: set[KEvent] = set()
            for e, when in ctx.events.items():
                when[index] = default_event
                if torch.all(when < 0):
                    to_remove.add(e)
            for e in to_remove:
                ctx.events.pop(e)
        for k, tensor in ctx.info.items():
            if (default := default_info.get(k)) is not None:
                tensor[index] = default
        for k, tensor in ctx.data.items():
            if (default := default_data.get(k)) is not None:
                tensor[index] = default
    
    def __clear_default(self, mask: torch.Tensor | None = None):
        index: tuple[slice | int | torch.Tensor , ...]
        if mask is None:
            index = (slice(None),) * self.ndim + (self.max_depth,)
        else:
            assert mask.shape == self.shape
            index = (mask, self.max_depth)
        self.__clear_context(self.buf, index)
    
    def _get(self, idx: torch.Tensor, mask: torch.Tensor | None = None, device=None):

        dim = self.ndim
        shape = self.shape
        H = self.max_depth
        idx = self.__check_index(idx, mask)

        def _gather(x: torch.Tensor):
            assert x.shape[:dim] == shape and x.shape[dim] == H + 1
            _idx = idx.reshape(idx.shape + (1,) * (x.ndim - dim)).\
                expand(idx.shape + (1,) + x.shape[dim+1:])
            y = torch.gather(x, dim, _idx).squeeze(dim)
            return y.to(device=device)

        out = self.buf.apply_transform(_gather)
        assert out.shape == shape
        return out
    
    def _clear_invalid(self):
        H = self.buf.shape[-1]
        invalid = (torch.arange(H, device=self.depths.device) >= self.depths.unsqueeze(-1))
        self.__clear_context(self.buf, invalid)
    
    def _set(self, idx: torch.Tensor, ctx: _Context, mask: torch.Tensor | None):
        
        dim = self.ndim
        shape = self.shape
        H = self.max_depth

        idx = self.__check_index(idx, mask).to(device=self.buf.device)

        def set_tensor(dst: torch.Tensor, src: torch.Tensor):
            assert dst.shape[:dim] == src.shape[:dim] == shape
            assert dst.shape[dim:] == (H + 1,) + src.shape[dim:]
            src = src.to(device=dst.device)
            idx_ = idx.reshape(idx.shape + (1,) * (dst.ndim - dim)).\
                expand(idx.shape + (1,) + src.shape[dim:]).to(device=dst.device)
            # src: (*shape, *elem_shape)
            # dst: (*sahpe, H, *elem_shape)
            dst.scatter_(dim, idx_, src.unsqueeze(dim))
        
        # update tokens
        buf = self.buf
        set_tensor(buf.tokens, ctx.tokens)
        set_tensor(buf.lengths, ctx.lengths)

        # update events
        for event, when in ctx.events.items():
            when_buf = buf.when(event).clone()
            set_tensor(when_buf, when)
            buf.events[event] = when_buf
        
        # update data
        for k, v in ctx.data.items():
            try:
                dst = buf.data[k]
            except KeyError:
                buf.data[k] = dst = buf.make_tensor(v.shape[dim:], v.dtype)
            set_tensor(dst, v)

        # update info
        for k, v in ctx.info.items():
            try:
                dst = buf.info[k]
            except KeyError:
                continue
            set_tensor(dst, v)
        
        if torch.any(clear_mask := (idx == H)):
            self.__clear_default(clear_mask)

    def get_top(self, mask: torch.Tensor | None = None, device=None):
        valid = self.depths > 0
        if mask is not None:
            valid = valid & mask
        return self._get(self.depths - 1, valid, device)
    
    def set_top(self, ctx: _Context, mask: torch.Tensor | None = None):
        valid = self.depths > 0
        if mask is not None:
            valid = valid & mask
        self._set(self.depths - 1, ctx, valid)
    
    @contextmanager
    def top(self, mask: torch.Tensor | None = None, device=None):
        top = self.get_top(mask, device)
        if __debug__:
            depths = self.depths.clone()
        try:
            yield top
        finally:
            if __debug__:
                if torch.any(self.depths != depths):
                    raise RuntimeError("Inconsistent depth.")
            self.set_top(top, mask)
    
    def grows(self, mask: torch.Tensor | None = None, device=None):
        self.push(mask)
        return self.top(mask, device)

    def pop(self, mask: torch.Tensor | None, empty_ok: bool = False):
        if mask is None:
            depths = self.depths - 1
        else:
            depths = torch.where(mask, self.depths - 1, self.depths)
        if empty_ok:
            depths = depths.clamp_min(0)
        elif __debug__ and torch.any(depths < 0):
            raise IndexError("Cannot pop from empty stack.")
        self.depths = depths

    def push(self, mask: torch.Tensor | None, full_ok: bool = False):
        if mask is None:
            depths = self.depths + 1
        else:
            depths = torch.where(mask, self.depths + 1, self.depths)
        if full_ok:
            depths = depths.clamp_max(self.max_depth - 1)
        elif __debug__ and torch.any(depths >= self.max_depth):
            raise IndexError("Cannot push a full stack.")
        self.depths = depths
    
    def enumerate(self):
        for idx in itertools.product(*(range(n) for n in self.shape)):
            depth = int(self.depths[idx])
            for d in range(depth):
                idx_ = idx + (d,)
                yield idx_, self.buf.tokens_at(*idx_)