import datasets

# Use a pipeline as a high-level helper
from transformers import pipeline, PreTrainedTokenizerFast
from torch.utils.data import Dataset, DataLoader
from IPython import embed
import spacy
import os
import tqdm
import torch
import numpy as np
from torch.optim import AdamW

def isword(sequence):
    """Return True if *sequence* is a space-prefixed alphanumeric word or a lone comma.

    A "word" here is a leading single space followed by one or more alphanumeric
    characters (e.g. ``" cat"``, ``" 42"``); a comma with optional surrounding
    whitespace also counts. Any other string, including ``""``, returns False.
    """
    # startswith() instead of sequence[0] so an empty string returns False
    # rather than raising IndexError.
    return (sequence.startswith(" ") and sequence[1:].isalnum()) or sequence.strip() == ","

class RandomPermutationDocDatasetWithBPE(Dataset):
    """Map-style dataset that tokenizes documents and assigns tokens a permuted generation order.

    Each item of ``hf_dataset`` must provide a ``'text'`` field. The text is split
    into lines (on ``</RESERVEDENDOFNATURALSENT>`` markers and on newlines). Each line
    is tokenized and grouped into segments, and the segments are reordered. The
    reordering is one of: identity, a full random permutation, or a truncated
    Fisher-Yates shuffle that only swaps adjacent word segments. The result is packed
    into tensors of length ``max_len``.

    Args:
        tokenizer: Tokenizer providing ``tokenize``, ``convert_tokens_to_string``,
            ``convert_tokens_to_ids`` and ``pad_token_id``.
        hf_dataset: Indexable dataset whose items have a ``'text'`` key.
        max_len: Length of every returned tensor.
        autoregressive: If True, segments keep their left-to-right order.
        bidirectional_rate: Probability of taking the branch that marks a
            random-length prefix in the ``bidirectional`` output.
    """

    # Marker that the preprocessing in ``main`` inserts between natural sentences.
    _SENT_SEP = "</RESERVEDENDOFNATURALSENT>"

    def __init__(self, tokenizer, hf_dataset, max_len=2048, autoregressive=False, bidirectional_rate=0.5):
        self.data = hf_dataset
        self.tokenizer = tokenizer
        self.max_len = max_len
        # Use an unconstrained random permutation of segments instead of truncated Fisher-Yates.
        self.vanilla_permutation = False
        # Shuffle the per-line permutation/term blocks across the document.
        self.shuffle_sentence = False
        self.autoregressive = autoregressive
        # Created lazily on first use (see _ensure_rng).
        self.rng = None
        self.bidirectional_rate = bidirectional_rate

    def __len__(self):
        return len(self.data)

    def _ensure_rng(self, seed):
        """Create ``self.rng`` seeded with *seed* unless a generator already exists."""
        if self.rng is None:
            self.rng = np.random.Generator(np.random.PCG64DXSM(seed=seed))

    @staticmethod
    def _segments_from_boundaries(boundaries, n_tokens):
        """Build a ``(2, k)`` int32 array of ``[start, end)`` spans covering ``[0, n_tokens)``.

        ``boundaries`` are ascending token indices at which a new segment starts.
        Index 0 always opens the first segment, so tokens before the first boundary
        are kept, and a line without any boundary becomes a single segment. Returns a
        ``(2, 0)`` array when there are no tokens.
        """
        if n_tokens == 0:
            return np.zeros((2, 0), dtype=np.int32)
        starts = [0] + [b for b in boundaries if 0 < b < n_tokens]
        ends = starts[1:] + [n_tokens]
        return np.array([starts, ends], dtype=np.int32)

    @staticmethod
    def _line_boundaries(decoded):
        """Return indices of tokens that start a new segment under ``proc_doc``'s rule.

        A token starts a segment when its decoded text is neither purely alphabetic nor
        numeric, unless the previous token decoded to whitespace only. In that case the
        token stays in the whitespace token's segment.
        """
        boundaries = []
        after_space = False
        for idx, piece in enumerate(decoded):
            if not after_space and not piece.isalpha() and not piece.isnumeric():
                boundaries.append(idx)
            after_space = piece.strip() == ""
        return boundaries

    @staticmethod
    def _is_sent(s):
        """Heuristic: stripped text starts with a letter and ends with a letter or one of ``!.?``."""
        stripped = s.strip()
        return stripped[0:1].isalpha() and (stripped[-1] in '!.?' or stripped[-1].isalpha())

    @classmethod
    def _iter_lines(cls, target_doc):
        """Yield the non-empty lines of *target_doc*.

        The text is split on the sentence marker, then on newlines. Every piece except
        the last one of each marker-delimited chunk keeps its trailing newline.
        """
        for target_full in target_doc.split(cls._SENT_SEP):
            pieces = target_full.split("\n")
            for line in [p + "\n" for p in pieces[:-1]] + pieces[-1:]:
                if line != "":
                    yield line

    def truncated_Fisher_Yates(self, tokens, segments=None):
        """Return a permutation of segment indices that only reorders runs of adjacent words.

        Walking from the last segment down, each word segment (per ``isword``) is
        swapped with a uniformly chosen member of {itself} plus the contiguous run of
        word segments directly before it. Non-word segments never move.
        """
        if segments is None:
            segments = self.split_segments(tokens)

        n_segments = segments.shape[1]
        permutation_ids = np.arange(n_segments)
        words = [
            self.tokenizer.convert_tokens_to_string(tokens[start:end])
            for start, end in zip(segments[0], segments[1])
        ]

        for i in range(n_segments - 1, 0, -1):
            if not isword(words[permutation_ids[i]]):
                continue
            can_swap = []
            for j in range(i - 1, -1, -1):
                if not isword(words[permutation_ids[j]]):
                    break
                can_swap.append(j)
            if not can_swap:
                continue
            can_swap.append(i)
            target_swap = can_swap[self.rng.integers(len(can_swap))]
            permutation_ids[i], permutation_ids[target_swap] = permutation_ids[target_swap], permutation_ids[i]
        return permutation_ids

    def generate_mask(self, tokens, token_i, token_j):
        """Return per-token term ids for ``tokens[token_i:token_j]``: 0s for a word, -1s otherwise."""
        word = self.tokenizer.convert_tokens_to_string(tokens[token_i: token_j])
        fill = 0 if isword(word) else -1
        return np.full((token_j - token_i,), fill, dtype=np.int32)

    def split_segments(self, tokens):
        """Split *tokens* into segments, starting a new one at every non-alphabetic token.

        Returns a ``(2, k)`` int32 array of ``[start, end)`` spans covering all tokens.
        """
        boundaries = [
            idx for idx, tok in enumerate(tokens)
            if not self.tokenizer.convert_tokens_to_string([tok]).isalpha()
        ]
        return self._segments_from_boundaries(boundaries, len(tokens))

    def proc_doc(self, target_doc, item_idx=None):
        """Tokenize *target_doc* and build fixed-length model inputs.

        Args:
            target_doc: Raw document text; sentence markers and newlines delimit the
                lines that are permuted independently.
            item_idx: Dataset index. It is used in diagnostics and, if no generator
                exists yet, as the RNG seed (0 when None).

        Returns:
            dict of four ``torch.long`` tensors of length ``max_len``:
            - ``input_ids``: padded with ``pad_token_id``.
            - ``permutation_ids``: token positions in generation order, padded with -1.
            - ``term_ids``: per-position term ids, padded with -1.
            - ``bidirectional``: 1 on a prefix, 0 elsewhere.

        Raises:
            ValueError: if no line of the document yields any token.
        """
        self._ensure_rng(0 if item_idx is None else item_idx)
        input_ids_clean_list = []
        term_ids_clean_list = []
        permutation_ids_clean_list = []
        base = 0
        for target in self._iter_lines(target_doc):
            not_a_natural_sent = not self._is_sent(target)
            tokens = self.tokenizer.tokenize(target)
            if len(tokens) >= self.max_len:
                tokens = tokens[0:self.max_len - 2]
            if not tokens:
                continue
            # Document is full: stop entirely so the packed lines stay contiguous.
            if base + len(tokens) >= self.max_len:
                break

            decoded = [self.tokenizer.convert_tokens_to_string([tok]) for tok in tokens]
            segments_ij = self._segments_from_boundaries(self._line_boundaries(decoded), len(tokens))
            n_segments = segments_ij.shape[1]
            input_ids_clean = self.tokenizer.convert_tokens_to_ids(tokens)

            if self.autoregressive or not_a_natural_sent:
                permutation = np.arange(n_segments)
            elif self.vanilla_permutation:
                permutation = self.rng.permutation(n_segments)
            else:
                permutation = self.truncated_Fisher_Yates(tokens, segments_ij)

            # Token positions in generation order: one contiguous run per segment.
            permutation_ids_clean = np.concatenate(
                [np.arange(segments_ij[0, i], segments_ij[1, i]) for i in permutation]
            )
            term_ids_clean = np.concatenate(
                [self.generate_mask(tokens, segments_ij[0, i], segments_ij[1, i]) for i in permutation]
            )
            # Mark the last word token (in permutation order) and the line's final position with 1.
            zero_positions = np.flatnonzero(term_ids_clean == 0)
            if zero_positions.size > 0:
                term_ids_clean[zero_positions[-1]] = 1
                term_ids_clean[-1] = 1

            if self.vanilla_permutation:
                term_ids_clean = np.zeros_like(term_ids_clean)
                term_ids_clean[-1] = 1

            input_ids_clean_list.append(input_ids_clean)
            term_ids_clean_list.append(term_ids_clean)
            permutation_ids_clean_list.append(permutation_ids_clean + base)
            base += len(input_ids_clean)

        # Checked before concatenating: np.concatenate([]) would raise first.
        if not input_ids_clean_list:
            print("Warning, input_ids_clean is empty")
            print(target_doc)
            if item_idx is not None:
                print(item_idx)
            raise ValueError("proc_doc: document produced no tokens")

        input_ids_clean = np.concatenate(input_ids_clean_list)
        if self.shuffle_sentence:
            sent_shuffs = self.rng.permutation(len(permutation_ids_clean_list))
            term_ids_clean_list = [term_ids_clean_list[i] for i in sent_shuffs]
            permutation_ids_clean_list = [permutation_ids_clean_list[i] for i in sent_shuffs]
        term_ids_clean = np.concatenate(term_ids_clean_list)
        permutation_ids_clean = np.concatenate(permutation_ids_clean_list)
        n_clean = len(input_ids_clean)

        input_ids = torch.full((self.max_len,), self.tokenizer.pad_token_id, dtype=torch.long)
        permutation_ids = torch.full((self.max_len,), -1, dtype=torch.long)
        term_ids = torch.full((self.max_len,), -1, dtype=torch.long)
        input_ids[0:n_clean] = torch.tensor(input_ids_clean)
        permutation_ids[0:len(permutation_ids_clean)] = torch.tensor(permutation_ids_clean)
        term_ids[0:len(term_ids_clean)] = torch.tensor(term_ids_clean)

        bidirectional = torch.zeros(self.max_len, dtype=torch.long)
        if self.rng.random() <= self.bidirectional_rate:
            # Leave a random-length suffix (at least 1 token when n_clean >= 2) unmarked.
            predicted = self.rng.uniform(low=1., high=n_clean)
            bidirectional[0:n_clean - int(predicted)] = 1

        return {
            "input_ids": input_ids,
            "permutation_ids": permutation_ids,
            "term_ids": term_ids,
            "bidirectional": bidirectional,
        }

    def __getitem__(self, item):
        self._ensure_rng(item)
        return self.proc_doc(self.data[item]['text'], item_idx=item)


def main():
    """Build (or load) the sentence-packed corpora, then open an IPython shell.

    Each corpus stage (Wikipedia, Project Gutenberg, BookCorpus) is cached under
    ``./<name><limit>``. When a cache is missing, or ``redo`` is True, that stage is
    built, saved, and the process exits. The script is therefore re-run once per
    stage. Once every cache exists, a dataset is wrapped around the BookCorpus cache,
    a smoke-test document is processed, and ``embed()`` opens a shell.
    """
    tokenizer = PreTrainedTokenizerFast(tokenizer_file="insnet_tokenizer.json")
    tokenizer.pad_token = "<|endoftext|>"
    tokenizer.eos_token = "<|endoftext|>"
    tokenizer.bos_token = "<|endoftext|>"
    redo = False

    nlp = spacy.load('en_core_web_sm')

    limit = 128
    # Chunks are packed to at most ``limit - 2`` tokens; a longer single sentence is its own chunk.
    budget = limit - 2
    sent_sep = "</RESERVEDENDOFNATURALSENT>"

    def split_sentences(passage):
        """Split *passage* into consecutive pieces that end at spaCy sentence ends."""
        pieces = []
        start = 0
        for sent in nlp(passage).sents:
            pieces.append(passage[start:sent.end_char])
            start = sent.end_char
        return pieces

    def pack_sentences(passages):
        """Greedily pack the sentences of *passages* into chunks joined by ``sent_sep``.

        A sentence that would push the current chunk over ``budget`` tokens closes the
        current chunk and starts the next one, so no sentence is repeated. Sentences
        remaining at the end are flushed as a final chunk. Chunks may span passage
        boundaries within one batch.
        """
        chunks = []
        current = []
        accumulated = 0
        for passage in passages:
            if passage == "":
                continue
            for sent in split_sentences(passage):
                token_len = len(tokenizer.tokenize(sent))
                if current and accumulated + token_len > budget:
                    chunks.append(sent_sep.join(current))
                    current, accumulated = [], 0
                current.append(sent)
                accumulated += token_len
                if accumulated > budget:
                    # Only reachable when ``sent`` alone exceeds the budget.
                    chunks.append(sent_sep.join(current))
                    current, accumulated = [], 0
        if current:
            chunks.append(sent_sep.join(current))
        return chunks

    def process(x):
        return {'text': pack_sentences(x['text'])}

    def process_gutenberg(x):
        paragraphs = []
        for book in x['text']:
            if book == "":
                continue
            paragraphs.extend(paragraph + "\n\n" for paragraph in book.split("\n\n"))
        return {'text': pack_sentences(paragraphs)}

    def build_or_load(cache_dir, done_msg, load_raw, fn, batch_size):
        """Build and save the stage then exit if the cache is missing; otherwise load it."""
        if not os.path.exists(cache_dir) or redo:
            dataset_raw = load_raw()
            dataset_new = dataset_raw.map(fn, batched=True, batch_size=batch_size, num_proc=128,
                                          remove_columns=dataset_raw.column_names)
            dataset_new.save_to_disk(cache_dir)
            print(done_msg)
            exit()
        dataset_new = datasets.load_from_disk(cache_dir)
        print(done_msg)
        return dataset_new

    build_or_load('./wiki%d' % limit, "Wikidata Done.",
                  lambda: datasets.load_dataset("wikimedia/wikipedia", "20231101.en")['train'],
                  process, 16)
    build_or_load('./guten%d' % limit, "Bookdata Done.",
                  lambda: datasets.load_dataset("manu/project_gutenberg", split="en"),
                  process_gutenberg, 1)
    dataset_new = build_or_load('./bc%d' % limit, "Bookdata Done.",
                                lambda: datasets.load_dataset("bookcorpus/bookcorpus")['train'],
                                process, 256)

    dataset = RandomPermutationDocDatasetWithBPE(tokenizer, hf_dataset=dataset_new, autoregressive=False)
    # Smoke test; ``a`` is left as a local so it can be inspected in the embed() shell.
    a = dataset.proc_doc("\n\n\n\n\n\n\n\n\n\n\n\n\n")
    embed()
    exit()



# Run the preprocessing / inspection pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()