import re, logging
from typing import Dict, List, Tuple, Union, Set
import copy

from transformers.tokenization_utils import PreTrainedTokenizer

from .utils import all_subclasses

from .constants import *
from .utils import *
from .states import *

logger = logging.getLogger(LOGGER_NAME)

class Operator:
    """Base node of the behavior graph built from a specification.

    Both attributes start out as ``None``; subclasses assign them when they
    are constructed.
    """

    def __init__(self):
        # What this operator wraps (set by subclasses).
        self.content = None
        # Follow-up node in a chain (only assigned by ``Next`` in this file).
        self.next = None

class LoopOperator(Operator):
    """Operator whose body may repeat, bounded by ``loop_limit``.

    The limit is read from the text after the last underscore of ``op_type``
    (e.g. ``'until_5'`` gives 5). When ``op_type`` contains no underscore the
    limit falls back to 10000.
    """

    def __init__(self, op_type : str):
        super().__init__()
        _, separator, suffix = op_type.rpartition('_')
        if separator:
            # int() raises ValueError when the suffix is not a valid integer.
            self.loop_limit = int(suffix)
        else:
            self.loop_limit = 10000

class Always(LoopOperator):
    """Loop operator registered under the name ``'always'``.

    It has no exit condition of its own; only the first argument is kept,
    as the loop body.
    """

    NAME = 'always'

    def __init__(self, op_type : str, args : List):
        super().__init__(op_type)
        body = args[0]
        self.content = body

class Until(LoopOperator):
    """Loop operator registered under the name ``'until'``.

    Takes exactly two arguments: the loop body (stored as ``content``) and
    the exit condition (stored as ``condition``), which must be a
    ``DecodingState``.

    Raises:
        AssertionError: if ``args`` does not hold exactly two entries or the
            second entry is not a ``DecodingState``.
    """

    NAME = 'until'

    def __init__(self, op_type : str, args : List):
        assert len(args) == 2, f'Unparseable structure'
        # isinstance instead of an exact type comparison, so instances of
        # DecodingState subclasses are accepted as exit conditions too.
        assert isinstance(args[1], DecodingState), f'Exit condition for {type(self)} must be DecodingState'
        super().__init__(op_type)
        self.content = args[0]
        self.condition = args[1]

class Next(Operator):
    """Sequencing operator registered under the name ``'next'``.

    The arguments are stored as a linked chain: the first argument becomes
    ``content`` and any remaining arguments are wrapped in a further ``Next``
    stored in ``next``.
    """

    NAME = 'next'

    def __init__(self, args : List):
        super().__init__()
        self.content = args[0]
        remaining = args[1:]
        if len(remaining) > 0:
            # Recursively chain the rest of the sequence.
            self.next = Next(remaining)

def _make_operator(op_type : str, args : List):
    """Instantiate the operator class selected by ``op_type``.

    Loop operators are matched by prefix, so a suffix such as ``'_5'`` is
    passed through to the class. ``Next`` must match its name exactly.

    Raises:
        ValueError: if ``op_type`` matches none of the known operators.
    """
    # Checked in the same order as before: Always first, then Until.
    for loop_cls in (Always, Until):
        if op_type.startswith(loop_cls.NAME):
            return loop_cls(op_type, args)
    if op_type == Next.NAME:
        return Next(args)
    raise ValueError(f'Unknown operator type {op_type}')

###
#
###

class TransitionMonitor:

    """
    The transition monitor executes timesteps. It maintains the current state
    as a stack (``self._stack``) whose top element is the operator or decoding
    state expected next; the stack is empty once the behavior has finished.
    """

    def __init__(self, specification : tuple, tokenizer : PreTrainedTokenizer):
        self.specification = specification
        self.tokenizer = tokenizer
        self._extract_transition_monitor()

    def _extract_transition_monitor(self):
        """
        Build the states, the operator graph and the initial stack from
        ``self.specification``.

        Raises:
            AssertionError: if the behavior component is malformed or refers
                to an unknown state.
        """

        # Names of all operator classes that declare a NAME; used for membership tests only.
        operator_names = {subcls.NAME for subcls in all_subclasses(Operator) if getattr(subcls, 'NAME', False)}

        def _make_graph(curr_expr):
            if isinstance(curr_expr, tuple):
                # Length is checked first so that an empty tuple fails the assertion instead of raising IndexError.
                assert len(curr_expr) > 1 and curr_expr[0].split('_')[0] in operator_names, f'Malformed behavior at {curr_expr}'
                proc_args = [_make_graph(arg) for arg in curr_expr[1:]]
                new_operator = _make_operator(curr_expr[0], proc_args)
                self._operators.add(new_operator)
                return new_operator
            assert curr_expr in self.states, f'Unknown state {curr_expr}'
            # Every occurrence of a state in the graph gets its own copy.
            return copy.deepcopy(self.states[curr_expr])

        self.states = extract_states(self.specification, self.tokenizer)
        self._operators = set()
        behavior = get_components_type(self.specification, BEHAVIOR_KEY)
        assert len(behavior) == 2, f'{BEHAVIOR_KEY} formatted incorrectly'
        top_level_node = _make_graph(behavior[1])

        self._stack = [top_level_node]
        self._init_op = self._stack[0]

    def make_copy(self):
        """Return a fresh monitor rebuilt from the same specification and tokenizer."""
        return TransitionMonitor(self.specification, self.tokenizer)

    def exit_reached(self):
        """Return True when the stack is empty, i.e. nothing is left to execute."""
        return self._stack == []

    def matches_state(self, input_tokens : List):
        """
        Find the state whose ``matches`` result for ``input_tokens`` is longest.

        Returns:
            ``(state, matched_seq)`` for the longest match (the first one
            wins on ties), or ``(None, None)`` when no state matches.
        """
        matches = []
        for state in self.states.values():
            matched_seq = state.matches(input_tokens)
            if matched_seq is not None:
                matches.append((state, matched_seq))
        if not matches:
            return None, None
        return max(matches, key=lambda match : len(match[1]))

    def accept_state(self, proposed_state : DecodingState, from_state : Union[Operator, DecodingState]=None):
        """
        Check whether ``proposed_state`` is acceptable at ``from_state``
        (defaults to the top of the stack).

        The check descends through ``content`` of nested operators until it
        reaches a decoding state; an ``Until`` additionally accepts its exit
        condition. Returns False when there is nothing left on the stack.
        """
        if from_state is None:
            if not self._stack:
                # Consistent with get_valid_states, which reports no valid states here.
                return False
            current = self._stack[-1]
        else:
            current = from_state
        if isinstance(current, DecodingState):
            return current == proposed_state
        if isinstance(current, Until) and proposed_state == current.condition:
            return True
        # Pass the state itself down (not its type), so the comparison at the
        # leaf is made between two decoding states.
        return self.accept_state(proposed_state, from_state=current.content)

    def get_valid_states(self, from_state : Union[Operator, DecodingState]=None, incl_op=False, valid_states : List=None):
        """
        Collect the decoding states that may come next, starting at
        ``from_state`` (defaults to the top of the stack).

        With ``incl_op`` set, an ``Until`` also contributes its exit condition
        and a loop whose ``loop_limit`` is <= 0 is not descended into.
        A list passed as ``valid_states`` is extended in place and returned.
        """
        if from_state is None and self._stack == []: return []
        if valid_states is None: valid_states = []
        current = self._stack[-1] if from_state is None else from_state
        if isinstance(current, DecodingState):
            valid_states.append(current)
        elif isinstance(current, Until) and incl_op:
            valid_states.append(current.condition)

        # NOTE(review): this stops at loop_limit <= 0 while step/_push still
        # allow loop_limit == 0; confirm which bound is intended.
        if isinstance(current, LoopOperator) and incl_op and current.loop_limit <= 0:
            return valid_states

        if not isinstance(current, Operator):
            return valid_states
        return self.get_valid_states(from_state=current.content, incl_op=incl_op, valid_states=valid_states)

    def step(self, state : DecodingState):
        """
        Consume ``state`` by advancing the stack.

        Operators on top of the stack are expanded until a decoding state is
        reached, which must equal ``state``. An ``Until`` whose condition
        equals ``state`` is removed from the stack instead.

        Raises:
            AssertionError: if ``state`` is not what the stack expects, or a
                loop has to repeat although its limit is used up.
            IndexError: if the stack is already empty.
        """
        # gets next element of stack
        current = self._stack.pop()
        if isinstance(current, DecodingState):
            assert current == state, f'Verification of state failed, expected {current.name} but received {state.name}'
            return
        if isinstance(current, Until):
            if state == current.condition: return
            assert current.loop_limit >= 0, f'Verification of state failed, expected {current.condition.name} but received {state.name}'
        elif isinstance(current, LoopOperator):
            # Without this check an exhausted loop is re-pushed without its
            # body and popped again forever, ending in a RecursionError.
            assert current.loop_limit >= 0, f'Verification of state failed, loop limit of {type(current).__name__} exhausted but received {state.name}'
        self._push(current)
        self.step(state)

    def _push(self, to_add : Union[Operator, DecodingState]):
        """
        Put ``to_add`` on the stack, expanding operators as needed.

        A ``Next`` chain is unrolled so that its first element ends up on top.
        A loop operator is pushed followed by its body while its
        ``loop_limit`` is >= 0, and the limit is decremented for each push of
        the body. Loops reached through a ``Next`` chain are expanded at this
        point as well.
        """
        if isinstance(to_add, Next):
            if to_add.next: self._push(to_add.next)
            self._push(to_add.content)
        elif isinstance(to_add, LoopOperator):
            self._stack.append(to_add)
            if to_add.loop_limit >= 0:
                to_add.loop_limit -= 1
                self._stack.append(to_add.content)
        else:
            self._stack.append(to_add)

    def validate_sequence_from_string(self, string : str):
        """
        Split ``string`` at the state texts and validate the resulting
        sequence of ``[state_text, following_text]`` steps.

        The text before the first state marker is treated as initialization
        and must be non-empty.

        Raises:
            AssertionError: if the input is empty or starts with a state text.
        """
        state_texts = [v.text for v in self.states.values()]
        known_texts = set(state_texts)
        # Regex alternation takes the first alternative that matches, so try
        # longer texts first; otherwise a text that is a prefix of another
        # would win. dict.fromkeys removes duplicates and keeps order stable.
        alternatives = sorted(dict.fromkeys(state_texts), key=len, reverse=True)
        pattern = '(' + '|'.join(re.escape(text) for text in alternatives) + ')'
        split_lst = re.split(pattern, string)
        # skip empty
        if split_lst[0] == '': split_lst = split_lst[1:]

        # the first step should be an initialization
        first_entry = split_lst[0] if split_lst else ''
        assert first_entry and first_entry not in known_texts, f'Start of input {first_entry} should not be a state!'

        split_lst = split_lst[1:]

        # Each state marker opens a new step; the text after it joins that step.
        steps = []
        for entry in split_lst:
            if entry in known_texts: steps.append([])
            steps[-1].append(entry)
        return self.validate_sequence(steps)

    def validate_sequence(self, state_history : List[Tuple[str, str]]):
        """
        Replay ``state_history`` on a fresh copy of this monitor.

        Each entry is a pair whose first element is a state text; the second
        element is ignored. Loop limits are reset to 10000 on the copy before
        replaying. A warning is logged if the behavior has not finished after
        the last entry.

        Returns:
            True when the replay completes without an assertion failing.

        Raises:
            AssertionError: if a step is invalid or transitions continue
                after the exit was reached.
            KeyError: if a state text is unknown.
        """
        monitor = self.make_copy()
        state_dict = { state.text : state for state in monitor.states.values() }
        for operator in monitor._operators:
            if isinstance(operator, LoopOperator):
                operator.loop_limit = 10000
        for dec_str, _ in state_history:
            assert not monitor.exit_reached(), f'Exit state reached but there are still state transitions occurring!'
            matched_state = state_dict[dec_str]
            monitor.step(matched_state)
        if not monitor.exit_reached():
            logger.warning(f'Exit state not reached for prompt example, ended at {monitor._stack[-1].name if isinstance(monitor._stack[-1], DecodingState) else type(monitor._stack[-1])}. Example may be malformed!')
        return True

