import random
from functools import partial
import torch
from transformers import AutoTokenizer

from _utils import pad_seq
from _abstract_task.data import HFDataModule
from Pretraining.preprocessing import DataProcessor

# The tokenizer warning "Token indices sequence length is longer than the specified maximum sequence length for this model (642 > 512). Running this sequence through the model will result in indexing errors"
# can be safely ignored, because we apply our own rules to truncate the tokenized sentences.
class PretrainingHFDataModule(HFDataModule):
    """Data module that wires a tokenizer, preprocessing and batching together.

    Raw datasets are preprocessed with ``DataProcessor`` and then wrapped in
    ``CorruptionDataset`` (bound to the task config and tokenizer) so that
    every item is corrupted on the fly when it is fetched.
    """

    def __init__(self, config):
        """Load the tokenizer named by ``config.tokenizer`` and set up the base module."""
        self.config = config
        tokenizer = AutoTokenizer.from_pretrained(config.tokenizer)
        self.tokenizer = tokenizer

        source_datasets = config.datasets
        # Template for name of cache file of preprocessed dataset.
        # (Would be placed under the same directory of source dataset arrow file)
        cache_name_template = (
            f"preprocessed_maxlen={config.max_sequence_length}-{{split}}.arrow"
        )
        cache_dir = config.datasets_cache_dir
        # Each preprocessed split gets wrapped by this item-transforming dataset.
        make_corruption_dataset = partial(
            CorruptionDataset, task_config=config, tokenizer=tokenizer,
        )

        # Initiate the hf data module
        super().__init__(
            dataset=source_datasets,
            cache_name_template=cache_name_template,
            cache_dir=cache_dir,
            tfm_dataset_cls=make_corruption_dataset,
        )

    def _preprocess(self, dataset, split, cache_file_path):
        """Run ``DataProcessor`` over ``dataset``, caching to ``cache_file_path``.

        ``split`` is accepted but not used here. Nothing is returned.
        """
        processor = DataProcessor(
            hf_tokenizer=self.tokenizer,
            max_sequence_length=self.config.max_sequence_length,
            sentence_length_threshold=self.config.sent_shuffle_min_sent_len,
        )
        processor.process_hfdataset(
            hf_dataset=dataset,
            cache_file_name=cache_file_path,
            num_proc=self.config.datasets_num_proc,
        )

    def _collate(self, samples: list[dict]):
        """Turn a list of per-item dicts into one dict of batched values.

        The keys of the first sample decide which fields are batched. Fields
        whose entry in the padding table is ``None`` are stacked with
        ``torch.stack``; all others go through ``pad_seq`` with the listed pad
        value. A field missing from the table raises ``KeyError``.
        """
        # Pad value per field; None means "stack as-is, no padding".
        pad_values = {
            "corrupted_ids": self.tokenizer.pad_token_id,
            "mlm_labels": -100,
            "segment_ids": 0,
            "permutation": self.config.max_sequence_length - 1,
            "sentence_marks": -1,
            "inter_segment_labels": None,
            "paragraph_ids": -1,
            "sentence_lengths": -1,
            "document_ids": -1,
            "sentence_permutation": -1,
        }
        # Every field that is passed in gets batchized here, because an
        # unbatchized item returned from collate slows down data loading.
        batch = {}
        for field in samples[0]:
            column = [sample[field] for sample in samples]
            pad_value = pad_values[field]
            if pad_value is None:
                batch[field] = torch.stack(column)
            else:
                batch[field] = pad_seq(column, pad=pad_value)
        return batch


class CorruptionDataset(torch.utils.data.Dataset):
    "A wrapper of Dataset to apply item transform for preparing data needed by transformer."

    def __init__(self, dset, task_config, tokenizer):
        """Keep references to the inputs and precompute vocabulary lookups.

        Args:
            dset: the wrapped dataset (indexed and measured elsewhere in this class).
            task_config: task configuration, stored as ``self.config``.
            tokenizer: object exposing ``<name>_token_id`` attributes and ``vocab_size``.
        """
        self.dataset = dset
        self.config = task_config
        self.tokenizer = tokenizer

        # Ids of the special tokens this tokenizer actually defines.
        special_ids = set()
        for name in ("pad", "bos", "eos", "unk", "cls", "sep", "mask"):
            token_id = getattr(tokenizer, f"{name}_token_id")
            if token_id is not None:
                special_ids.add(token_id)
        self.special_token_ids = special_ids

        # nonspecial_vocab[id in nonspecial vocab] = id in original vocab
        ordinary_ids = [
            token_id
            for token_id in range(tokenizer.vocab_size)
            if token_id not in special_ids
        ]
        self.nonspecial_vocab = torch.tensor(ordinary_ids)

    def __len__(self):
        """Return how many entries the wrapped dataset holds."""
        num_entries = len(self.dataset)
        return num_entries

    def __getitem__(self, i):
        entry = self.dataset[i]
        sample = {}

        if self.config.inter_segment_task:
            (
                sample["inter_segment_labels"],
                entry["input_ids"],
                entry["segmentA_length"],
                token_permutation,
            ) = self.manipulate_segments(
                ids=entry["input_ids"].tolist(),
                segmentA_length=entry["segmentA_length"],
                dataset=self.dataset,
                apply_permutation_immediately=not self.config.use_electra,
                extended_sentence_lengths=entry["extended_sentence_lengths"],
                document_id=entry["document_id"],
            )
            if self.config.use_electra and self.config.inter_segment_task in [
                "sop",
                "sso",
            ]:
                sample["permutation"] = token_permutation

        # Corrupt input sequence
        sample["corrupted_ids"], sample["mlm_labels"] = self.make_mlm_inputs(
            ids=entry["input_ids"], segmentA_length=entry["segmentA_length"],
        )

        # Shuffle sentences for text structure prediction task
        if self.config.tsp_loss_weight:
            # after the token corruption b/c we don't want unit_marks out of order (think if there is a unit cross the sentences)
            (
                token_permutation,
                # <int>(L), permute tokens to permute sentences
                # Given permutation[i] = j, moving j th token to position i
                sample["sentence_marks"],
                # <int>(L), index of sentence the token belongs to against unshuffled sentences
                new_segmentA_length,
                # int, resplit into two segments accoridng to shuffled sentences
                sentence_permutation,
                # <int>(S), sentence-level permutation
                sentence_lengths,
                # <int>(S), length of the sentence in order of unshuffled sentences
                sample["paragraph_ids"],
                # <int>(S), id of pargraph which sentence belongs to in order of unshuffled sentences
            ) = self.shuffle_sentences(
                sequence_length=len(entry["input_ids"]),
                segmentA_length=entry["segmentA_length"],
                extended_sentence_lengths=entry["extended_sentence_lengths"],
                extended_paragraph_ids=entry["extended_paragraph_ids"],
            )
            if self.config.use_electra:
                sample["permutation"] = token_permutation
            else:
                sample["corrupted_ids"] = sample["corrupted_ids"][token_permutation]
                sample["mlm_labels"] = sample["mlm_labels"][token_permutation]
                entry["segmentA_length"] = new_segmentA_length

        # Create segment ids
        sample["segment_ids"] = torch.tensor(
            [0] * entry["segmentA_length"]
            + [1] * (len(sample["corrupted_ids"]) - entry["segmentA_length"])
        )

        return sample

    def make_mlm_inputs(self, ids, segmentA_length):
        is_selectable = torch.full((len(ids),), True)
        is_selectable[0] = is_selectable[-1] = False  # CLS and ending SEP
        is_selectable[segmentA_length - 1] = False  # middle/ending SEP

        corrupted_ids = ids.clone()
        is_selected = torch.zeros(len(ids), dtype=torch.bool)
        remained_prob = 1.0

        if self.config.mask_probability:
            to_mask = torch.rand(len(ids)) < self.config.mask_probability
            to_mask &= is_selectable
            corrupted_ids[to_mask] = self.tokenizer.mask_token_id
            is_selected[to_mask] = True
            is_selectable &= ~to_mask
            remained_prob -= self.config.mask_probability

        if self.config.replace_probability:
            replace_prob = self.config.replace_probability / remained_prob
            to_replace = torch.rand(len(ids)) < replace_prob
            to_replace &= is_selectable
            num_replace = to_replace.sum()
            _random_ids = torch.randint(len(self.nonspecial_vocab), (num_replace,))
            random_ids = self.nonspecial_vocab[_random_ids]
            corrupted_ids[to_replace] = random_ids
            is_selected[to_replace] = True
            is_selectable &= ~to_replace
            remained_prob -= self.config.replace_probability

        if self.config.original_probability:
            stay_prob = self.config.original_probability / remained_prob
            to_stay = torch.rand(len(ids)) < stay_prob
            is_selected[to_stay] = True

        mlm_labels = ids.masked_fill(~is_selected, -100)

        return corrupted_ids, mlm_labels

    def merge_sentences(
        self,
        extended_sentence_lengths,  # <int>(#SENT), #SENT = #sentences + #sentinel tokens
        extended_paragraph_ids,  # <int>(#SENT)
        # Note: what extended_xxx looks like:
        # [-1, xxx of sent1, xxx of sent2, ..., -1, ..., xxx of sentN, -1]
    ):
        sent_lens = extended_sentence_lengths.tolist()
        para_ids = extended_paragraph_ids.tolist()
        new_sent_lens, new_para_ids = [], []
        for i in range(len(sent_lens)):
            sent_len, para_id = sent_lens[i], para_ids[i]
            if sent_len >= 5 or sent_len == -1:
                new_sent_lens.append(sent_len)
                new_para_ids.append(para_id)
                continue

            front_mergible = new_sent_lens and new_sent_lens[-1] != -1
            behind_mergible = i != (len(sent_lens) - 1) and sent_lens[i + 1] != -1
            if front_mergible and behind_mergible:
                if new_sent_lens[-1] < sent_lens[i + 1]:
                    new_sent_lens[-1] += sent_len
                else:
                    sent_lens[i + 1] += sent_len
            elif front_mergible:
                new_sent_lens[-1] += sent_len
            elif behind_mergible:
                sent_lens[i + 1] += sent_len
            else:
                new_sent_lens.append(sent_len)
                new_para_ids.append(para_id)
        new_ext_sentence_lengths = torch.tensor(new_sent_lens)
        new_ext_paragraph_ids = torch.tensor(new_para_ids)

        assert new_ext_sentence_lengths.sum() == extended_sentence_lengths.sum()
        assert len(new_ext_paragraph_ids) == len(new_ext_sentence_lengths)
        return new_ext_sentence_lengths, new_ext_paragraph_ids

    def shuffle_sentences(
        self,
        sequence_length: int,  # == L
        segmentA_length: int,  # int
        extended_sentence_lengths,  # <int>(#SENT), #SENT = #sentences + #sentinel tokens
        extended_paragraph_ids,  # <int>(#SENT)
        # Note: what extended_xxx looks like:
        # [-1, xxx of sent1, xxx of sent2, ..., -1, ..., xxx of sentN, -1]
    ):
        # To maintain a reasonable peak memory (partially affected by number of sentences),
        # we force sentence merging on sequence that has too much sentences in different
        # pargraphs indivisually (which is rare case and due to corrupted text data)
        if len(extended_sentence_lengths) > (self.config.max_sequence_length / 5):
            extended_sentence_lengths, extended_paragraph_ids = self.merge_sentences(
                extended_sentence_lengths, extended_paragraph_ids
            )

        # Get non-sentinel true sentences' attributes
        non_sentinel = extended_sentence_lengths != -1  # <bool>(#SENT)
        sentence_lengths = extended_sentence_lengths[non_sentinel]
        paragraph_ids = extended_paragraph_ids[non_sentinel]

        # Split sequence into sentences
        positions = torch.arange(sequence_length)  # <int>(L)
        non_sentinel_token = torch.full((sequence_length,), True)
        non_sentinel_token[0] = non_sentinel_token[-1] = False
        non_sentinel_token[segmentA_length - 1] = False
        nonsentinel_positions = positions[non_sentinel_token]  # <int>(L - #sentinels)
        senticized_positions = list(
            nonsentinel_positions.split(sentence_lengths.tolist())
        )  # List[Tensor[int]], original token positions associated with each sentence

        # Shuffle and Reconcatenate sentences into sequence to get permutation
        (
            token_permutation,
            sentence_permutation,
            new_segmentA_length,
        ) = self._shuffle_sentences(
            sentences=senticized_positions,
            sequence_length=sequence_length,
            old_segmentA_length=segmentA_length,
        )
        # reverse_permutation = torch.argsort(permutation)
        # assert torch.equal(token_ids, token_ids[permutation][reverse_permutation])

        # Mark token with the ordinal of sentence it belongs to
        sentence_ordinals = []
        th_sentence = 0
        for length in extended_sentence_lengths:
            if length == -1:
                sentence_ordinals.append(-1)
            else:
                assert length > 0
                sentence_ordinals += [th_sentence] * length
                th_sentence += 1
        sentence_ordinals = torch.tensor(sentence_ordinals)
        assert len(sentence_ordinals) == sequence_length
        sentence_marks = sentence_ordinals[token_permutation]

        return (
            token_permutation,  # <int>(L), L = number of tokens
            sentence_marks,  # <int>(L)
            new_segmentA_length,  # int
            sentence_permutation,  # <int>(S), S = number of sentences
            sentence_lengths,  # <int>(S)
            paragraph_ids,  # <int>(S)
        )

    def _shuffle_sentences(
        self,
        sentences: list[
            torch.LongTensor
        ],  # token positions associated with each sentence
        sequence_length: int,
        old_segmentA_length: int,
    ):
        if self.config.tsp_shuffling_fixing_prob >= 1:
            token_permutation = torch.arange(sequence_length)
            sentence_permutation = torch.arange(len(sentences))
            segmentA_length = old_segmentA_length
            return token_permutation, sentence_permutation, segmentA_length

        ## Shuffle sentences
        if len(sentences) == 1:
            sentence_permutation = torch.zeros(1, dtype=torch.long)
        else:
            sentence_permutation = self._partial_shuffling(len(sentences))
            sentences = [sentences[i] for i in sentence_permutation]
        ## Add back sentinels
        CLS_token_position = torch.tensor([0])
        middle_SEP_token_position = torch.tensor([old_segmentA_length - 1])
        end_SEP_token_position = torch.tensor([sequence_length - 1])
        segmentA_len = old_segmentA_length
        if old_segmentA_length != sequence_length:  # has two segments
            assert len(sentences) >= 2
            half_point = sequence_length // 2
            segmentA_len = 1 + len(sentences[0])  # cls and the first sentence
            _insert_idx = 1  # search inserting postiion after the first sentence
            for sentence in sentences[1:]:
                behind_half_point = segmentA_len > half_point
                will_behind_half_point = segmentA_len + len(sentence) > half_point
                if behind_half_point or (
                    will_behind_half_point and random.random() < 0.5
                ):  # The same way with preprocessing to insert the middle SEP
                    break  # insert SEP before this sentence
                segmentA_len += len(sentence)
                _insert_idx += 1
            if _insert_idx == len(sentences):  # insert after the last sentence
                segmentA_len -= len(sentence)
                _insert_idx -= 1
            assert (
                1 <= _insert_idx <= (len(sentences) - 1)
            ), "middle [SEP] should be after the first sentence and before the last sentence"
            sentences.insert(_insert_idx, middle_SEP_token_position)
            segmentA_len += 1
        token_permutation = torch.cat(
            [CLS_token_position, *sentences, end_SEP_token_position]
        )  # <int>(L)
        assert len(token_permutation) == sequence_length

        return token_permutation, sentence_permutation, segmentA_len

    def _partial_shuffling(self, n: int) -> list[int]:
        # There is expectedly always 1 item not moved after shuffling (https://math.stackexchange.com/questions/3745803/expected-number-of-cards-in-original-position-in-a-shuffled-deck-of-52-cards)
        fix_rate = self.config.tsp_shuffling_fixing_prob - 1 / n
        positions = torch.arange(n)
        k = min(round(n * fix_rate), n - 2)  # notice n >= 2 here
        to_shuffle = torch.randperm(n)[k:]
        partial_positions = positions[to_shuffle]
        partial_positions = partial_positions[torch.randperm(len(partial_positions))]
        positions[to_shuffle] = partial_positions
        return positions

    def manipulate_segments(
        self,
        ids: list[int],  # token ids of sequence
        segmentA_length: int,
        dataset,
        apply_permutation_immediately: bool,
        extended_sentence_lengths: list[int],  # Note: marked as -1 at sentinel
        document_id: int,
    ):
        cls_token_id = self.tokenizer.cls_token_id
        sep_token_id = self.tokenizer.sep_token_id

        # Split into two segments if the sequence has only one segment
        if segmentA_length == len(ids):
            if len(extended_sentence_lengths) == 3:  # if only one non-sentinel sentence
                sep_insert_idx = len(ids) // 2
            else:
                segA_num_sents = round(
                    len(extended_sentence_lengths) / 2 + random.uniform(-0.3, 0.3)
                )
                extended_sentence_lengths[0] = 1  # CLS
                sep_insert_idx = sum(extended_sentence_lengths[:segA_num_sents])
            ids.insert(sep_insert_idx, self.tokenizer.sep_token_id)
            segmentA_length = sep_insert_idx + 1
            if len(ids) > self.config.max_sequence_length:
                ids.pop(-2)

        # Decide which segment-based corruption to do for this sequence
        do_replacing = do_swap = False
        if self.config.inter_segment_task == "nsp":
            do_replacing = random.random() < 0.5
            is_manipulated = int(do_replacing)
        elif self.config.inter_segment_task == "sop":
            do_swap = random.random() < 0.5
            is_manipulated = int(do_swap)
        elif self.config.inter_segment_task == "sso":
            is_manipulated = random.choice([0, 1, 2])
            do_replacing = is_manipulated == 1
            do_swap = is_manipulated == 2
        assert not (do_replacing and do_swap)

        # Do the manipulation
        token_permutation = range(len(ids))
        new_segmentA_length = segmentA_length
        if do_replacing:  # replace segment B with text from the other document
            tgt_length = len(ids) - segmentA_length - 1
            for _ in range(10):
                entry = dataset[random.randint(0, len(dataset) - 1)]
                if (
                    entry["document_id"] != document_id
                    and (len(entry["input_ids"]) - 3) > tgt_length
                ):
                    break
            random_segmentA = entry["input_ids"][1 : entry["segmentA_length"] - 1]
            random_text = random_segmentA
            if len(random_text) < tgt_length:
                random_segmentB = entry["input_ids"][entry["segmentA_length"] : -1]
                random_text = torch.cat([random_text, random_segmentB])
            random_text = random_text[:tgt_length]
            assert (
                len(random_text) <= tgt_length
            )  # It is possible we can't find text long enough in 10 random drawals (but it's a super rare case)
            ids = [*ids[:segmentA_length], *random_text.tolist(), sep_token_id]
        elif do_swap:
            if apply_permutation_immediately:
                new_segmentA_length = 1 + len(ids[segmentA_length:])
                ids = [cls_token_id, *ids[segmentA_length:], *ids[1:segmentA_length]]
            else:
                pids = list(range(len(ids)))
                token_permutation = [
                    0,
                    *pids[segmentA_length:],
                    *pids[1:segmentA_length],
                ]

        return (
            torch.tensor(is_manipulated),  # <bool>(),
            torch.tensor(ids),  # <int>(L)
            new_segmentA_length,  # int
            torch.tensor(token_permutation),  # <int>(L), used by only ELECTRA + SOP/SSO
        )

