# LimeTokenizer extends HuggingFace's PreTrainedTokenizerFast class.
# It preserves the same behavior, but augments the BatchEncoding object with additional annotation symbols.
# These annotation symbols, at both the linguistic-word and subword level, can be accessed
# in the same way as standard data dictionary items (e.g., `BatchEncoding.input_ids`).
# See example.py for usage.

from transformers import PreTrainedTokenizerFast, BatchEncoding
import spacy

# Maps each coarse part-of-speech tag string to a human-readable description.
# Insertion order is significant: LimeTokenizer enumerates this mapping to
# assign integer ids to the tags, so entries must not be reordered.
UPOS_GLOSSARY = dict(
    ADJ="adjective",
    ADP="adposition",
    ADV="adverb",
    AUX="auxiliary",
    CONJ="conjunction",
    CCONJ="coordinating conjunction",
    DET="determiner",
    INTJ="interjection",
    NOUN="noun",
    NUM="numeral",
    PART="particle",
    PRON="pronoun",
    PROPN="proper noun",
    PUNCT="punctuation",
    SCONJ="subordinating conjunction",
    SYM="symbol",
    VERB="verb",
    X="other",
    EOL="end of line",
    SPACE="space",
)


class LimeTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that annotates its output with spaCy tags.

    Every input text is first split into "linguistic words" by a spaCy
    pipeline. Those words are then handed to the base tokenizer as
    pre-split input, and each word's spaCy `pos_`, `tag_` and `ent_type_`
    values are propagated to all sub-word tokens of that word. The extra
    annotations are stored as additional keys of the returned BatchEncoding.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the base tokenizer, load spaCy and build the tag vocabularies.

        All arguments are forwarded unchanged to the base-class constructor.
        """
        super().__init__(*args, **kwargs)
        self.nlp = spacy.load("en_core_web_sm")
        tagger = self.nlp.get_pipe("tagger")
        ner_tagger = self.nlp.get_pipe("ner")
        # tag assigned to tokens that do not belong to any word (word id None)
        self.SPECIAL_TOKEN_POS_TAG = 'SPECIAL'
        self.upos_tag_vocab = {tag_str: i for i, tag_str in enumerate(UPOS_GLOSSARY)}
        self.upos_tag_vocab[self.SPECIAL_TOKEN_POS_TAG] = len(self.upos_tag_vocab)  # add special token
        self.pos_tag_vocab = {tag_str: i for i, tag_str in enumerate(tagger.labels)}
        self.pos_tag_vocab[self.SPECIAL_TOKEN_POS_TAG] = len(self.pos_tag_vocab)  # add special token
        self.ner_tag_vocab = {tag_str: i for i, tag_str in enumerate(ner_tagger.labels)}
        self.ner_tag_vocab[""] = len(self.ner_tag_vocab)  # add none token (word without entity type)
        self.ner_tag_vocab[self.SPECIAL_TOKEN_POS_TAG] = len(self.ner_tag_vocab)  # add special token
        # most recent result of `_batch_encode_plus`; returned by `__call__`
        self.batch_encodings = None

    def _tags_to_ids(self, taglist, group="UPOS"):
        """Map tag strings to their integer ids.

        Args:
            taglist: a flat list of tag strings, or a (nested) list of such lists.
            group: which vocabulary to use: "UPOS", "POS" or "NER".

        Returns:
            A list with the same nesting as `taglist` holding the ids. An empty
            list is returned for an empty `taglist` or an unknown `group`.

        Raises:
            KeyError: if a tag is not part of the selected vocabulary.
        """
        vocab_by_group = {
            "UPOS": self.upos_tag_vocab,
            "POS": self.pos_tag_vocab,
            "NER": self.ner_tag_vocab,
        }
        vocab = vocab_by_group.get(group)
        if vocab is None or len(taglist) == 0:
            return []
        # the first element decides whether this is a batch (list of lists)
        if isinstance(taglist[0], list):
            return [self._tags_to_ids(sub_list, group=group) for sub_list in taglist]
        return [vocab[tag] for tag in taglist]

    def _seq_to_ling_word_tags(self, input_sequence):
        """Split a text into linguistic words and collect their tags.

        Args:
            input_sequence: the text to run through the spaCy pipeline.

        Returns:
            Four parallel lists: the words (whitespace moved to the front of
            the following word, see `_move_whitespaces`), their `pos_` tags,
            their `tag_` tags and their `ent_type_` tags. Words that end up
            as the empty string are dropped from all four lists.
        """
        doc = self.nlp(input_sequence)
        words = self._move_whitespaces([token.text_with_ws for token in doc])  # keep whitespaces

        pos_word_sequence = []
        upos_tag_sequence = []
        pos_tag_sequence = []
        ner_tag_sequence = []
        # Filter instead of popping by index: popping several indices in
        # ascending order shifts the remaining ones and removes wrong entries.
        for word, token in zip(words, doc):
            if word == "":
                continue
            pos_word_sequence.append(word)
            upos_tag_sequence.append(token.pos_)
            pos_tag_sequence.append(token.tag_)
            ner_tag_sequence.append(token.ent_type_)
        return pos_word_sequence, upos_tag_sequence, pos_tag_sequence, ner_tag_sequence

    def _move_whitespaces(self, word_list):
        """Turn a single trailing space of a word into a leading space of the next one.

        Args:
            word_list: list of word strings, possibly ending in a space.

        Returns:
            A list of the same length. If the last word ended in a space, that
            space stays attached to the last word, since there is no next word.
        """
        pos_word_sequence = []
        prefix = ""
        for p in word_list:
            if p[-1:] == " ":
                pos_word_sequence.append(prefix + p[:-1])  # append without trailing ws
                prefix = " "  # indicate the next word has leading ws
            else:
                pos_word_sequence.append(prefix + p)
                prefix = ""
        # fix whitespace for last element (prefix can only be " " if there was one)
        if prefix == " ":
            pos_word_sequence[-1] = pos_word_sequence[-1] + " "
        return pos_word_sequence

    # overwrites PreTrainedTokenizerFast._batch_encode_plus in transformers/tokenization_utils_fast.py
    def _batch_encode_plus(self, batch_text_or_text_pairs, *args, **kwargs):
        """Encode a batch and attach word- and token-level tag annotations.

        Args:
            batch_text_or_text_pairs: non-empty list whose elements are either
                all strings or all `(text, text_pair)` tuples; the type of the
                first element decides how the whole batch is treated.
            *args, **kwargs: forwarded to the base-class implementation;
                `is_split_into_words` is always forced to True.

        Returns:
            A BatchEncoding with the base-class entries plus the keys
            `ling_words`, `ling_words_{upos,pos,ner}_tags`,
            `tokens_{upos,pos,ner}_tags` and `tokens_{upos,pos,ner}_tags_ids`.
            The same object is also stored in `self.batch_encodings`.

        Raises:
            NotImplementedError: if the first element is neither str nor tuple.
        """
        ling_words = []
        upos_word_tags = []
        pos_word_tags = []
        ner_word_tags = []
        # per example: number of words of the first sequence (0 for single texts);
        # needed to index into the merged word-tag lists of sequence pairs
        first_seq_lengths = []
        solo_str = False

        #  "A single sequence",
        # ("A tuple with a sequence", "And its pair")
        if isinstance(batch_text_or_text_pairs[0], str):
            solo_str = len(batch_text_or_text_pairs) == 1
            for text in batch_text_or_text_pairs:  # for input sequence
                words, upos_tags, pos_tags, ner_tags = self._seq_to_ling_word_tags(text)
                ling_words.append(words)
                upos_word_tags.append(upos_tags)
                pos_word_tags.append(pos_tags)
                ner_word_tags.append(ner_tags)
                first_seq_lengths.append(0)
        elif isinstance(batch_text_or_text_pairs[0], tuple):  # for sequence pair
            for sen1, sen2 in batch_text_or_text_pairs:
                sen1_words, sen1_upos, sen1_pos, sen1_ner = self._seq_to_ling_word_tags(sen1)
                sen2_words, sen2_upos, sen2_pos, sen2_ner = self._seq_to_ling_word_tags(sen2)
                ling_words.append((sen1_words, sen2_words))  # keep tuple for '_batch_encode_plus'
                upos_word_tags.append(sen1_upos + sen2_upos)  # merge to list for correct iteration
                pos_word_tags.append(sen1_pos + sen2_pos)  # merge to list for correct iteration
                ner_word_tags.append(sen1_ner + sen2_ner)  # merge to list for correct iteration
                first_seq_lengths.append(len(sen1_words))
        else:
            raise NotImplementedError("danger. abort.")

        number_of_batches = len(batch_text_or_text_pairs)
        kwargs["is_split_into_words"] = True  # the linguistic words are passed as pre-split input
        batched_encodings = super()._batch_encode_plus(batch_text_or_text_pairs=ling_words, *args, **kwargs)

        upos_token_tags = []
        pos_token_tags = []
        ner_token_tags = []
        for batch_idx in range(number_of_batches):
            upos_token_tags_batch = []
            pos_token_tags_batch = []
            ner_token_tags_batch = []
            word_ids = batched_encodings.word_ids(batch_idx)
            sequence_ids = batched_encodings.sequence_ids(batch_idx)
            for word_id, sequence_id in zip(word_ids, sequence_ids):
                if word_id is None:
                    # token does not belong to any word (e.g. added special token)
                    upos_token_tags_batch.append(self.SPECIAL_TOKEN_POS_TAG)
                    pos_token_tags_batch.append(self.SPECIAL_TOKEN_POS_TAG)
                    ner_token_tags_batch.append(self.SPECIAL_TOKEN_POS_TAG)
                    continue
                # Word ids are counted per sequence and restart at 0 for the
                # second sequence of a pair, while the word-tag lists of both
                # sequences were concatenated above: shift into the second half.
                if sequence_id == 1:
                    word_id += first_seq_lengths[batch_idx]
                upos_token_tags_batch.append(upos_word_tags[batch_idx][word_id])
                pos_token_tags_batch.append(pos_word_tags[batch_idx][word_id])
                ner_token_tags_batch.append(ner_word_tags[batch_idx][word_id])
            upos_token_tags.append(upos_token_tags_batch)
            pos_token_tags.append(pos_token_tags_batch)
            ner_token_tags.append(ner_token_tags_batch)

        # extend tokenizers.Encoding data dict with annotation symbols
        # (built once, after all examples have been processed)
        data_dict = dict(batched_encodings.items())
        data_dict["ling_words"] = ling_words
        data_dict["ling_words_upos_tags"] = upos_word_tags
        data_dict["ling_words_pos_tags"] = pos_word_tags
        data_dict["ling_words_ner_tags"] = ner_word_tags
        data_dict["tokens_upos_tags"] = upos_token_tags
        data_dict["tokens_upos_tags_ids"] = self._tags_to_ids(upos_token_tags, group="UPOS")  # resolve to ids
        data_dict["tokens_pos_tags"] = pos_token_tags
        data_dict["tokens_pos_tags_ids"] = self._tags_to_ids(pos_token_tags, group="POS")  # resolve to ids
        data_dict["tokens_ner_tags"] = ner_token_tags
        data_dict["tokens_ner_tags_ids"] = self._tags_to_ids(ner_token_tags, group="NER")  # resolve to ids
        batched_encodings = BatchEncoding(
            data_dict,
            batched_encodings.encodings,
        )

        # if only one string, remove leading batch axis to align with baseclass behaviour
        # NOTE(review): this also triggers for a list holding exactly one string
        # and never for a single sequence pair - confirm this is intended.
        if solo_str:
            batched_encodings = BatchEncoding(
                {
                    key: value[0] if len(value) > 0 and isinstance(value[0], list) else value
                    for key, value in batched_encodings.items()
                },
                batched_encodings.encodings,
            )
        self.batch_encodings = batched_encodings  # set as tokenizer state to avoid tagging twice for string-only case
        return self.batch_encodings

    # main method to tokenize and prepare for the model
    def __call__(self, *args, **kwargs):
        """Tokenize the inputs and return the annotated encoding.

        The base-class `__call__` is executed for its side effect only: its
        return value is discarded and `self.batch_encodings` is returned
        instead. This relies on the base class routing through
        `_batch_encode_plus`, which refreshes that attribute.
        """
        super().__call__(*args, **kwargs)
        return self.batch_encodings  # call above must finish first

    # converts a string to a sequence of ids (integer), using the tokenizer and vocabulary.
    def encode(self, text, text_pair=None, *args, **kwargs):
        """Encode `text` (and optionally `text_pair`) via the base-class `encode`.

        Raises:
            TypeError: if `text` is not a string, or if `text_pair` is truthy
                and not a string.
        """
        if not isinstance(text, str):
            raise TypeError("only strings supported for text")
        if text_pair and not isinstance(text_pair, str):
            raise TypeError("only strings supported for text_pair")
        batched_output = super().encode(text=text, text_pair=text_pair, *args, **kwargs)
        return batched_output

    # converts a sequence of ids in a string. Can be obtained using the encode method and also __call__
    def decode(self, *args, **kwargs):
        """Forward to the base-class `decode` unchanged."""
        return super().decode(*args, **kwargs)

    # convert a list of lists of token ids into a list of strings. Can be obtained using the __call__ method
    def batch_decode(self, *args, **kwargs):
        """Forward to the base-class `batch_decode` unchanged."""
        return super().batch_decode(*args, **kwargs)
