import random

from datasets import load_dataset

def get_wikitext2(nsamples, seed, seqlen, tokenizer, split="train"):
    """Build fixed-length token windows from WikiText-2 (raw, v1).

    The whole split is joined with blank lines and tokenized in one call;
    windows of ``seqlen`` tokens are then sliced out of that single sequence.

    Args:
        nsamples: Number of windows to draw. Only honoured for
            ``split="train"``; for ``"test"`` it is overwritten with the
            number of whole ``seqlen`` windows that fit in the tokenized text.
        seed: Value passed to ``random.seed`` (reseeds the global RNG).
        seqlen: Number of tokens per window.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors="pt")``
            whose result exposes a 2-D ``input_ids`` tensor
            (presumably shaped ``(1, n_tokens)`` -- confirm against caller).
        split: ``"train"`` (random window starts) or ``"test"``
            (consecutive, non-overlapping windows).

    Returns:
        ``(loader, enc)`` where ``loader`` is a list of ``(inp, tar)`` pairs,
        ``tar`` being a clone of ``inp``, and ``enc`` is the tokenizer output
        for the whole split.

    Raises:
        Exception: If ``split`` is neither ``"train"`` nor ``"test"``.
    """
    if split not in ("train", "test"):
        raise Exception(f"No such split: {split}")
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)

    enc = tokenizer("\n\n".join(dataset["text"]), return_tensors="pt")
    token_ids = enc.input_ids
    is_train = split == "train"
    if not is_train:
        # Evaluation covers the text with as many whole windows as fit.
        nsamples = token_ids.numel() // seqlen

    random.seed(seed)
    loader = []
    for window_idx in range(nsamples):
        if is_train:
            # randint is inclusive, so the window always ends at least one
            # token before the end of the sequence.
            start = random.randint(0, token_ids.shape[1] - seqlen - 1)
        else:
            start = window_idx * seqlen
        window = token_ids[:, start:start + seqlen]
        loader.append((window, window.clone()))
    return loader, enc


def get_longbench_context(dataset_name, seqlen, res_size, tokenizer, split=None):
    """Tokenize the ``context`` field of every LongBench ``*_e`` test sample.

    Args:
        dataset_name: LongBench task name; ``"_e"`` is appended to select the
            dataset configuration.
        seqlen: Upper bound on the number of tokens kept per sample.
        res_size: The token count is first rounded down to a multiple of this
            value.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors="pt")``
            whose result exposes an ``input_ids`` tensor.
        split: Unused; the ``test`` split is always loaded.

    Returns:
        A list with one ``(inp, tar)`` pair per dataset sample, where ``tar``
        is a clone of ``inp``.
    """
    samples = load_dataset('THUDM/LongBench', f"{dataset_name}_e", split='test')
    loader = []
    for sample in samples:
        token_ids = tokenizer(sample['context'], return_tensors="pt").input_ids
        n_tokens = token_ids.shape[-1]
        # Round down to a whole number of res_size chunks, then cap at seqlen.
        # NOTE(review): the cap is applied after rounding, so the final length
        # is a multiple of res_size only when seqlen is one too; contexts
        # shorter than res_size yield an empty slice -- confirm both are
        # intended.
        keep = min(n_tokens - n_tokens % res_size, seqlen)
        prefix = token_ids[:, :keep]
        loader.append((prefix, prefix.clone()))
    return loader


def get_c4(nsamples, seed, seqlen, tokenizer, split="train"):
    """Sample fixed-length token windows from the first English C4 shard.

    Each window comes from one randomly chosen document: documents are drawn
    until one tokenizes to at least ``seqlen`` tokens, then a random start
    offset is chosen inside it. The draw loop does not terminate if no
    document in the shard is long enough.

    Args:
        nsamples: Number of windows to draw. Only used for ``split="train"``;
            the ``"test"`` branch always draws 256 windows.
        seed: Value passed to ``random.seed`` (reseeds the global RNG). The
            ``"test"`` branch subsequently reseeds with ``0`` so the
            evaluation set does not depend on ``seed``.
        seqlen: Number of tokens per window.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors='pt')``
            whose result exposes a 2-D ``input_ids`` tensor.
        split: ``"train"`` (train shard) or ``"test"`` (validation shard).

    Returns:
        ``(loader, enc)`` where ``loader`` is a list of ``(inp, tar)`` pairs,
        ``tar`` being a clone of ``inp``. For ``"train"``, ``enc`` is the
        tokenizer output of the last sampled document (``None`` when
        ``nsamples`` is 0); for ``"test"`` it is always ``None``.

    Raises:
        Exception: If ``split`` is neither ``"train"`` nor ``"test"``.
    """
    if split not in ["train", "test"]:
        raise Exception(f"No such split: {split}")

    if split == "train":
        data = load_dataset(
            'allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train'
        )
    else:
        data = load_dataset(
            'allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation'
        )

    def _sample_window():
        """Return (encoding, window) for one randomly drawn long-enough document."""
        while True:
            doc_idx = random.randint(0, len(data) - 1)
            encoded = tokenizer(data[doc_idx]['text'], return_tensors='pt')
            if encoded.input_ids.shape[1] >= seqlen:
                break
        # A document with exactly seqlen tokens leaves no room for a random
        # offset; randint(0, -1) would raise ValueError, so start at 0.
        max_start = encoded.input_ids.shape[1] - seqlen - 1
        start = random.randint(0, max_start) if max_start >= 0 else 0
        return encoded, encoded.input_ids[:, start:start + seqlen]

    random.seed(seed)
    loader = []

    if split == "train":
        enc = None  # stays None when nsamples == 0
        for _ in range(nsamples):
            enc, window = _sample_window()
            loader.append((window, window.clone()))
        return loader, enc

    # Evaluation: a fixed set of 256 windows drawn with a fixed seed. (An
    # earlier variant concatenated the first 1100 documents and cut them into
    # consecutive windows; it was replaced by per-document sampling.)
    random.seed(0)
    for _ in range(256):
        _, window = _sample_window()
        loader.append((window, window.clone()))
    return loader, None


def get_c4_ko(nsamples, seed, seqlen, tokenizer, split="train"):
    """Sample fixed-length token windows from the first Korean mC4 shard.

    Each window comes from one randomly chosen document: documents are drawn
    until one tokenizes to at least ``seqlen`` tokens, then a random start
    offset is chosen inside it. The draw loop does not terminate if no
    document in the shard is long enough.

    Args:
        nsamples: Number of windows to draw. Only used for ``split="train"``;
            the ``"test"`` branch always draws 256 windows.
        seed: Value passed to ``random.seed`` (reseeds the global RNG). The
            ``"test"`` branch subsequently reseeds with ``0`` so the
            evaluation set does not depend on ``seed``.
        seqlen: Number of tokens per window.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors='pt')``
            whose result exposes a 2-D ``input_ids`` tensor.
        split: ``"train"`` (train shard) or ``"test"`` (validation shard).

    Returns:
        ``(loader, enc)`` where ``loader`` is a list of ``(inp, tar)`` pairs,
        ``tar`` being a clone of ``inp``. For ``"train"``, ``enc`` is the
        tokenizer output of the last sampled document (``None`` when
        ``nsamples`` is 0); for ``"test"`` it is always ``None``.

    Raises:
        Exception: If ``split`` is neither ``"train"`` nor ``"test"``.
    """
    if split not in ["train", "test"]:
        raise Exception(f"No such split: {split}")

    # referred to https://huggingface.co/datasets/bertin-project/mc4-sampling/raw/2acb6be7a6d41093cb0396d0fdc4de49daa5099d/mc4-sampling.py
    # Both shards are read through the generic 'json' builder, which exposes a
    # lone data file under the 'train' split -- hence split='train' for the
    # validation file as well.
    if split == "train":
        data = load_dataset(
            'json',
            data_files="https://huggingface.co/datasets/allenai/c4/resolve/1ddc917116b730e1859edef32896ec5c16be51d0/multilingual/c4-ko.tfrecord-00000-of-01024.json.gz",
            split='train'
        )
    else:
        data = load_dataset(
            'json',
            data_files="https://huggingface.co/datasets/allenai/c4/resolve/1ddc917116b730e1859edef32896ec5c16be51d0/multilingual/c4-ko-validation.tfrecord-00000-of-00001.json.gz",
            split='train'
        )

    def _sample_window():
        """Return (encoding, window) for one randomly drawn long-enough document."""
        while True:
            doc_idx = random.randint(0, len(data) - 1)
            encoded = tokenizer(data[doc_idx]['text'], return_tensors='pt')
            if encoded.input_ids.shape[1] >= seqlen:
                break
        # A document with exactly seqlen tokens leaves no room for a random
        # offset; randint(0, -1) would raise ValueError, so start at 0.
        max_start = encoded.input_ids.shape[1] - seqlen - 1
        start = random.randint(0, max_start) if max_start >= 0 else 0
        return encoded, encoded.input_ids[:, start:start + seqlen]

    random.seed(seed)
    loader = []

    if split == "train":
        enc = None  # stays None when nsamples == 0
        for _ in range(nsamples):
            enc, window = _sample_window()
            loader.append((window, window.clone()))
        return loader, enc

    # Evaluation: a fixed set of 256 windows drawn with a fixed seed. (An
    # earlier variant concatenated the first 1100 documents and cut them into
    # consecutive windows; it was replaced by per-document sampling.)
    random.seed(0)
    for _ in range(256):
        _, window = _sample_window()
        loader.append((window, window.clone()))
    return loader, None
