from copy import deepcopy
import random, re
from nltk import sent_tokenize


class DataProcessor:
    """Given a stream of input text, creates examples for pretraining.

    Usage:
        * Give a huggingface dataset to `process_hfdataset`, or
        * (iteratively) give batches of text strings to `__call__`.

    Terminology:
        * "Paragraph"s are split from a text on double newlines.
        * "Sentence"s are split from a paragraph by `nltk.sent_tokenize`, and
          additionally right after a period followed by a double quote.
        * A "segment" is one or multiple consecutive sentences which ends
          with a [SEP].
        * Tokens overflowing the target sequence length are trimmed and
          discarded.

    The instance is stateful: sentences are collected in `_current_*`
    attributes until an example is emitted, after which they are reset.
    """

    # A period, optional whitespace, then a double quote. Sentences returned
    # by nltk are additionally split right after every such match.
    _QUOTE_END_PATTERN = re.compile(r'\.\s*"')

    def __init__(
        self,
        hf_tokenizer,  # huggingface tokenizer
        max_sequence_length: int,  # max number of tokens in an example, including sentinel tokens (CLS, SEP).
        sentence_length_threshold: int,  # a sentence with fewer tokens than this is merged into a neighbouring sentence of the same paragraph.
    ):
        self.hf_tokenizer = hf_tokenizer
        self.max_sequence_length = max_sequence_length
        self.sentence_length_threshold = sentence_length_threshold

        # Collector state; must be reset every time an example is emitted
        # (see `_reset_collectors`).
        self._current_sentences = []  # list[list[str]], tokenized sentences
        self._current_num_tokens = 0  # number of non-sentinel tokens collected
        self._current_paragraph_ids = []  # list[int], paragraph id of each sentence
        self._current_target_length = self.max_sequence_length
        self._current_want_middle_sep = True

    def process_hfdataset(self, hf_dataset, text_col="text", **map_kwargs):
        """Run this processor over `hf_dataset` through its `map` method.

        Sets the `map` arguments this processor relies on; any extra keyword
        arguments are forwarded to `map` unchanged.
        """
        return hf_dataset.map(
            function=self,
            batched=True,  # receive b examples and be able to return >b / <b processed examples
            remove_columns=hf_dataset.column_names,  # processed data may have a different number of rows than the original columns
            with_indices=True,  # use example index as document index
            input_columns=[text_col],  # data column passed to `__call__`
            **map_kwargs,
        )

    # ==============================
    # Collect
    # ==============================

    def __call__(self, texts: list[str], document_ids: list[int]) -> dict:
        """Convert a batch of documents into pre-training examples.

        "Batch" here is the number of documents preprocessed at a time, not
        the model training batch size; the number of returned examples may
        differ from the number of input documents.

        Args:
            texts: raw documents, (batch size).
            document_ids: one id per document, (batch size).

        Returns:
            A dict of equally long lists, one entry per generated example.
        """
        new_examples = {
            "input_ids": [],  # list[list[int]], (returned batch size, #tokens)
            "segmentA_length": [],  # list[int], (returned batch size)
            "document_id": [],  # list[int], (returned batch size)
            "extended_paragraph_ids": [],  # list[list[int]], (returned batch size, #sentences + #sentinels)
            "extended_sentence_lengths": [],  # list[list[int]], (returned batch size, #sentences + #sentinels)
        }

        for document_id, text in zip(document_ids, texts):
            stripped = (p.strip() for p in text.split("\n\n"))
            paragraphs = [p for p in stripped if p]
            for paragraph_id, paragraph in enumerate(paragraphs):
                for sentence in self._sentencize(paragraph):
                    tokens = self.hf_tokenizer.tokenize(sentence)
                    if not tokens:
                        # the tokenizer returned nothing for this sentence
                        continue
                    self._append_sentence(tokens, paragraph_id)
                    # [CLS] + ending [SEP], plus the middle [SEP] if wanted
                    num_sentinels = 2 + int(self._current_want_middle_sep)
                    current_sequence_length = self._current_num_tokens + num_sentinels
                    if current_sequence_length >= self._current_target_length:
                        self._generate_and_add_example(new_examples, document_id)

            # End of document: flush what is left so that an example never
            # spans two documents.
            if self._current_num_tokens != 0:
                self._generate_and_add_example(new_examples, document_id)

        return new_examples

    def _sentencize(self, text: str) -> list[str]:
        """Split `text` into sentences with nltk, then after every '."'."""
        sentences = []
        for sentence in sent_tokenize(text):
            sentences.extend(self._split_after_quoted_period(sentence))
        return sentences

    @classmethod
    def _split_after_quoted_period(cls, sentence: str) -> list[str]:
        """Split `sentence` right after each period followed by a double quote.

        The matched text stays at the end of the piece preceding the split.
        Splitting is done on match positions rather than through a textual
        placeholder, so no character sequence in the input can be mistaken
        for a split marker. Empty trailing pieces are not returned.
        """
        pieces = []
        start = 0
        for match in cls._QUOTE_END_PATTERN.finditer(sentence):
            pieces.append(sentence[start : match.end()])
            start = match.end()
        tail = sentence[start:]
        if tail:
            pieces.append(tail)
        return pieces

    def _append_sentence(self, tokens: list[str], paragraph_id: int) -> None:
        """Add one tokenized sentence to the collectors."""
        self._current_sentences.append(tokens)
        self._current_num_tokens += len(tokens)
        self._current_paragraph_ids.append(paragraph_id)

    # ==============================
    # Create data instance
    # ==============================

    def _generate_and_add_example(self, new_examples: dict, document_id: int) -> None:
        """Create an example from the collected sentences and append it.

        Every field of the new example is appended to the list of the same
        key in `new_examples`, then the collectors are reset. Does nothing
        when no sentence has been collected.

        Raises:
            AssertionError: if the generated example violates one of the
                internal consistency checks of `_check_example`.
        """
        if not self._current_sentences:
            return

        # 1. Merging
        # a too short sentence is merged into a neighbouring sentence of the same paragraph
        sentences, paragraph_ids = self._merge_short_sentence(
            sentences=self._current_sentences, paragraph_ids=self._current_paragraph_ids
        )

        # 2. Concatenation and inserting sentinel tokens
        (
            sequence,  # list[str]
            segmentA_length,  # int, number of tokens from [CLS] to first [SEP] (included)
            extended_paragraph_ids,  # list[int], paragraph ids of sentences, -1 at sentinels
            extended_sentence_lengths,  # list[int], sentence lengths, -1 at sentinels
        ) = self._generate_input_sequence(
            sentences=sentences, paragraph_ids=paragraph_ids
        )
        ## truncation may have left the last sentence too short; merge it into
        ## the entry before it when both share a paragraph id (index -1 is the
        ## ending [SEP], and a sentinel's id of -1 never equals a sentence's)
        num_merged_sentences = len(sentences)
        if (
            extended_sentence_lengths[-2] < self.sentence_length_threshold
            and extended_paragraph_ids[-2] == extended_paragraph_ids[-3]
        ):
            extended_paragraph_ids.pop(-2)
            last_sent_len = extended_sentence_lengths.pop(-2)
            extended_sentence_lengths[-2] += last_sent_len
            num_merged_sentences -= 1

        # 3. Check
        self._check_example(
            sequence=sequence,
            segmentA_length=segmentA_length,
            sentences=sentences,
            num_merged_sentences=num_merged_sentences,
            extended_paragraph_ids=extended_paragraph_ids,
            extended_sentence_lengths=extended_sentence_lengths,
        )

        # 4. Generate and add the new data point
        new_example = self._generate_example(
            tokens=sequence,
            segmentA_length=segmentA_length,
            document_id=document_id,
            extended_paragraph_ids=extended_paragraph_ids,
            extended_sentence_lengths=extended_sentence_lengths,
        )
        for key, value in new_example.items():
            new_examples[key].append(value)

        # 5. Reset collectors and statuses
        self._reset_collectors()

    def _check_example(
        self,
        sequence: list[str],
        segmentA_length: int,
        sentences: list[list[str]],
        num_merged_sentences: int,
        extended_paragraph_ids: list[int],
        extended_sentence_lengths: list[int],
    ) -> None:
        """Assert the internal consistency of a freshly generated example.

        Failures propagate as AssertionError instead of opening a debugger,
        so that non-interactive runs stop with a message.
        """
        sep_token = self.hf_tokenizer.sep_token
        # 2 = [CLS] + ending [SEP]; one more when a middle [SEP] was inserted
        num_sentinels = int(segmentA_length != len(sequence)) + 2

        if self._current_num_tokens + num_sentinels >= self._current_target_length:
            # enough tokens were collected, the sequence must hit the target length exactly
            assert (
                len(sequence) == self._current_target_length
            ), "sequence length differs from the target length"
        else:
            # not enough tokens to reach the target length, so nothing may have been truncated
            assert len(sequence) == (
                sum(len(s) for s in sentences) + num_sentinels
            ), "tokens were lost although no truncation was needed"

        assert sequence[segmentA_length - 1] == sep_token, "segment A does not end with [SEP]"
        assert sequence[1] != sep_token, "segment A is empty"
        assert sequence[-2] != sep_token, "last segment is empty"

        # Check extended sentence lengths
        assert (
            len(extended_sentence_lengths)
            == len(extended_paragraph_ids)
            == num_merged_sentences + num_sentinels
        ), "extended lists disagree with the number of sentences and sentinels"
        num_minus_one = 0
        sum_of_lengths = 0
        for i, sent_len in enumerate(extended_sentence_lengths):
            sum_of_lengths += abs(sent_len)  # a sentinel (-1) counts as one token
            if sent_len == -1:
                assert extended_paragraph_ids[i] == -1, "sentinel without paragraph id -1"
                num_minus_one += 1
                continue
            assert sent_len > 0, "empty sentence in example"
            if sent_len < self.sentence_length_threshold:
                # a short sentence may only remain when it has no mergeable neighbour
                assert (
                    extended_paragraph_ids[i] != extended_paragraph_ids[i - 1]
                    and extended_paragraph_ids[i] != extended_paragraph_ids[i + 1]
                ), "short sentence left unmerged next to a sentence of its paragraph"
        assert num_sentinels == num_minus_one, "unexpected number of sentinels"
        assert len(sequence) == sum_of_lengths, "sentence lengths do not sum to sequence length"

    def _reset_collectors(self) -> None:
        """Clear the collected sentences and draw settings for the next example."""
        self._current_sentences = []
        self._current_num_tokens = 0
        self._current_paragraph_ids = []
        self._current_want_middle_sep = random.random() > 0.1
        if random.random() < 0.05:
            ## small chance for random-length to adapt better to input with shorter lengths
            ## (lower bound clamped so that randint never gets an empty range)
            shortest = min(5, self.max_sequence_length)
            self._current_target_length = random.randint(shortest, self.max_sequence_length)
        else:
            self._current_target_length = self.max_sequence_length

    def _merge_short_sentence(
        self, sentences: list[list[str]], paragraph_ids: list[int]
    ) -> tuple[list[list[str]], list[int]]:
        """Merge sentences shorter than the threshold into a same-paragraph neighbour.

        The inputs are not modified.

        Returns:
            (merged sentences, paragraph id of every merged sentence)
        """
        assert len(sentences) == len(paragraph_ids)  # = number of sentences
        # Work on copies: sentences get extended / replaced while merging.
        pending = [list(sentence) for sentence in sentences]
        last_index = len(pending) - 1

        new_sentences, new_para_ids = [], []
        # `pending[i + 1]` may be replaced inside the loop; the list iterator
        # reads each element only when it reaches it, so the replacement is seen.
        for i, sentence in enumerate(pending):
            para_id = paragraph_ids[i]
            # Sentences that satisfy the threshold are kept as they are
            if len(sentence) >= self.sentence_length_threshold:
                new_sentences.append(sentence)
                new_para_ids.append(para_id)
                continue

            # Merge this too short sentence into another one of its paragraph
            prev_mergeable = bool(new_sentences) and new_para_ids[-1] == para_id
            next_mergeable = i != last_index and para_id == paragraph_ids[i + 1]
            if not prev_mergeable and not next_mergeable:
                new_sentences.append(sentence)
                new_para_ids.append(para_id)
                continue

            if prev_mergeable and next_mergeable:
                ## prefer merging into the shorter neighbour
                into_previous = len(new_sentences[-1]) <= len(pending[i + 1])
            else:
                into_previous = prev_mergeable

            if into_previous:
                new_sentences[-1].extend(sentence)
            else:
                pending[i + 1] = sentence + pending[i + 1]

        # Check: merging must neither drop nor reorder tokens
        assert [t for s in sentences for t in s] == [
            t for s in new_sentences for t in s
        ]
        assert len(new_sentences) == len(new_para_ids)  # = number of (merged) sentences

        return new_sentences, new_para_ids

    def _generate_input_sequence(
        self, sentences: list[list[str]], paragraph_ids: list[int]
    ) -> tuple[list[str], int, list[int], list[int]]:
        """Concatenate sentences into one token sequence with sentinel tokens.

        Returns:
            (sequence, segmentA_length, extended_paragraph_ids,
            extended_sentence_lengths); the two "extended" lists have one
            entry per sentence and per sentinel, with -1 at sentinels.
        """
        sep_token = self.hf_tokenizer.sep_token

        # Add CLS and first sentence
        sequence = [self.hf_tokenizer.cls_token, *sentences[0]]
        extended_paragraph_ids = [-1, paragraph_ids[0]]
        extended_sentence_lengths = [-1, len(sentences[0])]

        # Concatenate sentences and insert SEP in the middle if needed
        segmentA_length = None
        half_point = self._current_target_length / 2
        for tokens, paragraph_id in zip(sentences[1:], paragraph_ids[1:]):

            # add middle SEP if needed
            if self._current_want_middle_sep and segmentA_length is None:
                behind_half_point = len(sequence) > half_point
                will_behind_half_point = len(sequence) + len(tokens) > half_point
                ## Aim for two segments of similar length. Whether putting this
                ## sentence in segment A or B balances them better is unknown
                ## here, so flip a coin when the sentence crosses the half point.
                if behind_half_point or (
                    will_behind_half_point and random.random() < 0.5
                ):
                    sequence.append(sep_token)
                    extended_paragraph_ids.append(-1)
                    extended_sentence_lengths.append(-1)
                    segmentA_length = len(sequence)

            # append the sentence
            sequence.extend(tokens)
            extended_paragraph_ids.append(paragraph_id)
            extended_sentence_lengths.append(len(tokens))

        # Truncation
        _tgt_len = self._current_target_length - 1  # leave 1 for ending SEP
        excess_length = len(sequence) - _tgt_len
        if excess_length > 0:
            sequence = sequence[:-excess_length]
            # the whole excess is subtracted from the last sentence's length
            extended_sentence_lengths[-1] -= excess_length

        # Append ending SEP
        sequence.append(sep_token)
        extended_paragraph_ids.append(-1)
        extended_sentence_lengths.append(-1)
        if segmentA_length is None:
            # no middle SEP: segment A is the whole sequence
            segmentA_length = len(sequence)

        return (
            sequence,
            segmentA_length,
            extended_paragraph_ids,
            extended_sentence_lengths,
        )

    def _generate_example(
        self,
        tokens,
        segmentA_length,
        document_id,
        extended_paragraph_ids,
        extended_sentence_lengths,
    ) -> dict:
        """Pack one example into a dict, converting the tokens to ids."""
        input_ids = self.hf_tokenizer.convert_tokens_to_ids(tokens)
        return {
            "input_ids": input_ids,  # list[int], (#tokens)
            "segmentA_length": segmentA_length,  # int
            "document_id": document_id,  # int
            "extended_paragraph_ids": extended_paragraph_ids,  # list[int], (#sentences + #sentinels)
            "extended_sentence_lengths": extended_sentence_lengths,  # list[int], (#sentences + #sentinels)
        }
