import dataclasses
from collections.abc import Iterable
from typing import NewType, TypeVar

from .automaton import (
    Symbol,
    State,
    Transition,
    Automaton,
    AutomatonContainer,
    WeightedAutomaton,
    WeightedAutomatonContainer
)
from .reserved import ReservedSymbol
from .semiring import Semiring


# Distinct static type for stack-alphabet symbols; at runtime these are plain ints.
StackSymbol = NewType('StackSymbol', int)
# Type variable for transition weights; it parameterizes Semiring[Weight] and the
# weighted automaton classes in this module.
Weight = TypeVar('Weight')

@dataclasses.dataclass(frozen=True)
class PushdownAutomatonTransition(Transition):
    """A ``Transition`` extended with a stack operation.

    In addition to the fields inherited from ``Transition``, it records the
    stack symbol the transition pops and the stack symbols it pushes. The
    dataclass is frozen: instances are immutable, compare by field value, and
    are hashable.

    NOTE(review): the order in which ``pushed_symbols`` end up on the stack
    (first element on top vs. last element on top) is not established in this
    module. Confirm against the code that simulates the automaton.
    """
    # Stack symbol popped by this transition.
    popped_symbol: StackSymbol
    # Stack symbols pushed by this transition; the type permits an empty tuple.
    pushed_symbols: tuple[StackSymbol, ...]

class PushdownAutomaton(Automaton):
    """Interface for an automaton that also reads and writes a stack.

    ``transitions`` forwards to the base class and is only re-annotated to
    yield ``PushdownAutomatonTransition`` objects. The stack-alphabet size,
    initial stack symbol, and accept state must be supplied by subclasses;
    here they raise ``NotImplementedError``.
    """

    def transitions(self) -> Iterable[PushdownAutomatonTransition]:
        """Return the inherited transitions, typed as pushdown transitions."""
        inherited = super().transitions()
        return inherited  # type: ignore

    def stack_alphabet_size(self) -> int:
        """Number of symbols in the stack alphabet (abstract)."""
        raise NotImplementedError()

    def initial_stack_symbol(self) -> StackSymbol:
        """Stack symbol present when the automaton starts (abstract)."""
        raise NotImplementedError()

    def accept_state(self) -> State:
        """The automaton's accept state (abstract)."""
        raise NotImplementedError()

class PushdownAutomatonContainer(PushdownAutomaton, AutomatonContainer):
    """Concrete, mutable pushdown automaton with a single accept state.

    Stores the stack-alphabet size, the initial stack symbol, and the accept
    state itself. State count, input alphabet, initial state, and transition
    storage are handled by ``AutomatonContainer``.
    """

    _stack_alphabet_size: int
    _initial_stack_symbol: StackSymbol
    _accept_state: State

    def __init__(self,
        *,
        num_states: int=1,
        alphabet_size: int,
        stack_alphabet_size: int=1,
        initial_state: State=State(0),
        initial_stack_symbol: StackSymbol=StackSymbol(0),
        accept_state: State=State(0)
    ):
        """Create the container; add transitions afterwards with ``add_transition``.

        Raises:
            ValueError: if ``initial_stack_symbol`` is not in
                ``range(stack_alphabet_size)`` or ``accept_state`` is not in
                ``range(num_states)``.
        """
        super().__init__(
            num_states=num_states,
            alphabet_size=alphabet_size,
            initial_state=initial_state
        )
        # Symbols and states are 0-based indices, so both bounds are checked.
        if not 0 <= initial_stack_symbol < stack_alphabet_size:
            raise ValueError(
                f'initial_stack_symbol {initial_stack_symbol} is out of range '
                f'for a stack alphabet of size {stack_alphabet_size}'
            )
        if not 0 <= accept_state < num_states:
            raise ValueError(
                f'accept_state {accept_state} is out of range '
                f'for an automaton with {num_states} states'
            )
        self._stack_alphabet_size = stack_alphabet_size
        # NOTE(review): possibly redundant with super().__init__(initial_state=...);
        # kept unchanged. Confirm against AutomatonContainer.
        self._initial_state = initial_state
        self._initial_stack_symbol = initial_stack_symbol
        self._accept_state = accept_state

    def add_transition(self, transition: PushdownAutomatonTransition) -> None:
        """Store *transition* via the inherited ``_add_transition``."""
        self._add_transition(transition)

    def stack_alphabet_size(self) -> int:
        """Return the number of stack symbols given at construction."""
        return self._stack_alphabet_size

    def initial_stack_symbol(self) -> StackSymbol:
        """Return the stack symbol given as ``initial_stack_symbol``."""
        return self._initial_stack_symbol

    def accept_state(self) -> State:
        """Return the current accept state."""
        return self._accept_state

    def is_accept_state(self, state: State) -> bool:
        """Return True if *state* equals the accept state."""
        return state == self._accept_state

    def set_accept_state(self, state: State) -> None:
        """Replace the accept state.

        Unlike ``__init__``, this performs no range check on *state*.
        """
        self._accept_state = state

class WeightedPushdownAutomaton(PushdownAutomaton, WeightedAutomaton[Weight]):
    """Pushdown automaton whose transitions carry weights of type ``Weight``."""

    def transition_weights(self) -> Iterable[tuple[PushdownAutomatonTransition, Weight]]:
        """Return the inherited (transition, weight) pairs, typed for pushdown transitions."""
        weighted_pairs = super().transition_weights()
        return weighted_pairs  # type: ignore

class WeightedPushdownAutomatonContainer(WeightedPushdownAutomaton[Weight], WeightedAutomatonContainer[Weight]):
    """Concrete, mutable weighted pushdown automaton with a single accept state.

    Stores the stack-alphabet size, the initial stack symbol, and the accept
    state itself. State count, input alphabet, initial state, semiring, and
    weighted transition storage are handled by ``WeightedAutomatonContainer``.
    """

    _stack_alphabet_size: int
    _initial_stack_symbol: StackSymbol
    _accept_state: State

    def __init__(self,
        *,
        num_states: int=1,
        alphabet_size: int,
        stack_alphabet_size: int=1,
        initial_state: State=State(0),
        initial_stack_symbol: StackSymbol=StackSymbol(0),
        accept_state: State=State(0),
        semiring: Semiring[Weight]
    ):
        """Create the container; add transitions afterwards with ``set_transition_weight``.

        Raises:
            ValueError: if ``initial_stack_symbol`` is not in
                ``range(stack_alphabet_size)`` or ``accept_state`` is not in
                ``range(num_states)``. This matches the validation performed
                by the unweighted ``PushdownAutomatonContainer``.
        """
        super().__init__(
            num_states=num_states,
            alphabet_size=alphabet_size,
            initial_state=initial_state,
            semiring=semiring
        )
        # Symbols and states are 0-based indices, so both bounds are checked.
        if not 0 <= initial_stack_symbol < stack_alphabet_size:
            raise ValueError(
                f'initial_stack_symbol {initial_stack_symbol} is out of range '
                f'for a stack alphabet of size {stack_alphabet_size}'
            )
        if not 0 <= accept_state < num_states:
            raise ValueError(
                f'accept_state {accept_state} is out of range '
                f'for an automaton with {num_states} states'
            )
        self._stack_alphabet_size = stack_alphabet_size
        # NOTE(review): possibly redundant with super().__init__(initial_state=...);
        # kept unchanged. Confirm against WeightedAutomatonContainer.
        self._initial_state = initial_state
        self._initial_stack_symbol = initial_stack_symbol
        self._accept_state = accept_state

    @property
    def deterministic(self) -> bool:
        """Property alias for ``is_deterministic()``."""
        return self.is_deterministic()

    def set_transition_weight(self,
        transition: PushdownAutomatonTransition,
        weight: Weight
    ) -> None:
        """Record *weight* for *transition* via the inherited ``_set_transition_weight``."""
        self._set_transition_weight(transition, weight)

    def stack_alphabet_size(self) -> int:
        """Return the number of stack symbols given at construction."""
        return self._stack_alphabet_size

    def initial_stack_symbol(self) -> StackSymbol:
        """Return the stack symbol given as ``initial_stack_symbol``."""
        return self._initial_stack_symbol

    def accept_state(self) -> State:
        """Return the current accept state."""
        return self._accept_state

    def is_accept_state(self, state: State) -> bool:
        """Return True if *state* equals the accept state."""
        return state == self._accept_state

    def set_accept_state(self, state: State) -> None:
        """Replace the accept state.

        No range check is performed on *state*.
        """
        self._accept_state = state

    def is_deterministic(self) -> bool:
        """Return True if the transitions satisfy the DPDA determinism conditions.

        For every ``(state_from, popped_symbol)`` pair:
          * no two transitions share the same non-epsilon input symbol, and
          * if there is an epsilon transition, it is the only transition
            from that pair (no other epsilon or non-epsilon transition).

        Runs in time linear in the number of transitions.
        """
        # (state_from, popped_symbol) pairs with at least one transition so far.
        seen_configs = set()
        # (state_from, popped_symbol) pairs that have an epsilon transition.
        epsilon_configs = set()
        # (state_from, popped_symbol, symbol) for non-epsilon transitions seen so far.
        seen_triples = set()
        for t in self.transitions():
            config = (t.state_from, t.popped_symbol)
            if t.symbol != ReservedSymbol.EPSILON:
                triple = (t.state_from, t.popped_symbol, t.symbol)
                # Conflicts with a duplicate symbol or with an earlier epsilon move.
                if triple in seen_triples or config in epsilon_configs:
                    return False
                seen_triples.add(triple)
            else:
                # An epsilon move conflicts with any earlier move from this config.
                if config in seen_configs:
                    return False
                epsilon_configs.add(config)
            seen_configs.add(config)
        return True

