import torch
import dataclasses as dc
from core.utils.th import ObjectTensor
from core.utils.kv import KVTensors
from typing import Sequence, Callable, overload
from math import prod


def check_conflict(**conditions: bool):
    """Raise if more than one of the keyword flags is truthy.

    Each keyword names a mutually exclusive option and its value says whether
    that option was selected. Zero or one truthy flag is accepted.

    Raises:
        ValueError: when two or more flags are truthy; the message lists their
            names in the order they were passed.
    """
    selected = []
    for name, flag in conditions.items():
        if flag:
            selected.append(name)
    if len(selected) <= 1:
        return
    raise ValueError(f"Arguments conflict: {selected}")


def check_validity(*conditions: bool, msg: str | None = None):
    if not all(conditions):
        raise ValueError(msg)


def implies(a, b):
    """Logical implication ``a -> b`` on truthiness.

    False only when ``a`` is truthy and ``b`` is falsy. ``b`` is evaluated
    for truthiness only when ``a`` is truthy. Always returns a plain ``bool``.
    """
    return bool(b) if a else True
    

@dc.dataclass
class IndividualContext[Event, Info, Data]:

    tokens: torch.Tensor
    events: dict[Event, torch.Tensor]
    info: dict[Info, torch.Tensor]
    data: dict[Data, torch.Tensor]
    kvs: tuple[list[tuple[torch.Tensor, torch.Tensor]], ...]

    def __post_init__(self):
        assert self.tokens.ndim == 1

    
class ContextStack[Event, Info, Data](
    ObjectTensor[list[IndividualContext[Event, Info, Data]]]
):

    @staticmethod
    def empty(shape: Sequence[int], indices_device=None):
        indices = list(range(prod(shape)))
        data: tuple[list[IndividualContext], ...] = tuple([] for _ in indices)
        indices = torch.tensor(indices, device=indices_device).reshape(shape)
        return ContextStack(indices, data)
    
    def clear(self):
        for stack in self._objects:
            stack.clear()

    def top(self, idx):
        stack = self[idx]
        try:
            return stack[-1]
        except IndexError:
            return None
    
    def pop(self, idx):
        stack = self[idx]
        try:
            return stack.pop(-1)
        except IndexError:
            return None
    
    def push(self, idx, ctx: IndividualContext):
        stack = self[idx]
        stack.append(ctx)
