import torch

from transformers.tokenization_utils import PreTrainedTokenizer

from .constants import *
from .utils import *

###
#
###

class DecodingState:
    """A single named state parsed from a decoding specification.

    A state description is a tuple whose first element is the state name and
    whose remaining elements are components of the form ``(key, value, ...)``:

    * ``(':text', text)`` -- required; the text associated with the state.
    * ``(':output-constrs', c1, c2, ...)`` -- optional output constraints.
    * ``(':flags', f1, f2, ...)`` -- optional flags; ``':env-input'`` sets
      ``defer_to_env`` and ``':is-bool'`` sets ``is_bool``. Other flags are
      ignored.

    Single and double quote characters are removed from the text and from
    every constraint before they are tokenized.

    Two states compare equal (and hash equally) when their ``text`` is equal;
    the name, flags and constraints do not take part in the comparison.
    """

    _TEXT_KEY, _FLAGS_KEY, _CONSTRAINTS_KEY = ':text', ':flags', ':output-constrs'
    _ENV_INPUT_KEY, _BOOL_KEY = ':env-input', ':is-bool'

    def __init__(self, state_desc : tuple, tokenizer : "PreTrainedTokenizer"):
        """Parse ``state_desc`` and pre-tokenize its text and constraints.

        Args:
            state_desc: ``(name, component, ...)`` as described on the class.
            tokenizer: object providing ``tokenize`` and
                ``convert_tokens_to_ids``; kept on the instance because the
                constraint methods re-tokenize later.

        Raises:
            AssertionError: if no ``:text`` component is present, or if the
                state is flagged ``:env-input`` and also lists constraints.
        """
        self.name, self.tokenizer = state_desc[0], tokenizer
        text = None
        constraints = None
        self.is_bool = False
        self.defer_to_env = False
        for component in state_desc[1:]:
            key = component[0]
            if key == DecodingState._TEXT_KEY:
                text = component[1]
            elif key == DecodingState._CONSTRAINTS_KEY:
                constraints = component[1:]
            elif key == DecodingState._FLAGS_KEY:
                for flag in component[1:]:
                    if flag == DecodingState._ENV_INPUT_KEY:
                        self.defer_to_env = True
                    elif flag == DecodingState._BOOL_KEY:
                        self.is_bool = True

        # Raised explicitly (rather than via `assert`) so the validation is not
        # stripped under `python -O`; the exception type is kept for callers.
        if text is None:
            raise AssertionError(f'Must specify :text for state {state_desc}')
        self.text = self._strip_quotes(text)

        if self.defer_to_env and constraints:
            raise AssertionError('Cannot enforce output constraints while also being an environment action')

        self.add_constraints([self._strip_quotes(constr) for constr in constraints or ()])

        stripped_text = self.text.strip()
        self.tokens = self._encode('\n' + stripped_text)

        # we separate between what we want to add when we match the criteria (the text) and the list of accepting states
        # Accepting sequences: the newline-prefixed form first, then the text
        # with no prefix, a space prefix and a tab prefix.
        self._check_tokens = [self.tokens]
        self._check_tokens.extend(self._encode(prefix + stripped_text) for prefix in ('', ' ', '\t'))

    @staticmethod
    def _strip_quotes(value : str) -> str:
        """Return ``value`` with every single and double quote removed."""
        return value.replace('"', '').replace("'", '')

    def _encode(self, text : str):
        """Tokenize ``text`` with this state's tokenizer and return token ids."""
        return self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))

    def add_constraints(self, constraints : List[str]):
        """Set the state's output constraints to ``constraints``.

        NOTE: despite the name, this replaces any constraints already stored;
        it does not append to them. Each constraint is tokenized with a
        leading space.
        """
        # Copy so later mutation of the caller's list cannot desynchronise the
        # stored strings from the token ids computed below.
        self._constraint_strings = list(constraints)
        self.constraints = [self._encode(' ' + constr) for constr in self._constraint_strings]

    def remove_constraints(self, constraints_to_remove : List[str]):
        """Drop every stored constraint string found in ``constraints_to_remove``
        and rebuild the tokenized constraints from what remains."""
        self._constraint_strings = [constr for constr in self._constraint_strings if constr not in constraints_to_remove]
        self.constraints = [self._encode(' ' + constr) for constr in self._constraint_strings]

    def matches(self, input_tokens : torch.Tensor):
        """Return the result of ``seq_matches_any`` for this state's accepting
        token sequences against ``input_tokens``."""
        return seq_matches_any(self._check_tokens, input_tokens)

    def __eq__(self, other):
        """States are equal when their (quote-stripped) text is equal."""
        if not isinstance(other, DecodingState):
            return NotImplemented
        return self.text == other.text

    def __hash__(self):
        # Defining __eq__ alone makes a class unhashable; hash on the same
        # field that equality uses so equal states hash equally.
        return hash(self.text)

    def __repr__(self): return self.name

def extract_states(specification : tuple, tokenizer : PreTrainedTokenizer):
    """Build a ``{state name: DecodingState}`` mapping from a specification.

    The state descriptions are taken from the ``STATES_KEY`` component of
    ``specification`` (everything after that component's first element) and
    each one is turned into a ``DecodingState`` using ``tokenizer``.

    Raises:
        AssertionError: if two state descriptions share the same name.
    """
    states_lst = get_components_type(specification, STATES_KEY)[1:]
    states = [DecodingState(state, tokenizer) for state in states_lst]
    states_dict = {}
    for state in states:
        # Raised explicitly (rather than via `assert`) so that under
        # `python -O` a duplicate cannot silently overwrite an earlier state.
        if state.name in states_dict:
            raise AssertionError(f'Cannot have duplicate state specifications: {state.name}')
        states_dict[state.name] = state
    return states_dict