import numpy as np
import random
import torch
from datasets import load_dataset
from tqdm import tqdm
from torch.utils.data.dataset import Dataset

# Set seed for reproducibility
def set_seed(seed):
    """Seed every RNG this module relies on, for reproducible runs.

    Seeds Python's ``random`` module (used by all the ``get_*_sample``
    samplers), NumPy's global generator and PyTorch's global generator.

    Args:
        seed: integer seed applied to all three generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.random.manual_seed(seed)

# Wrapper for tokenized input IDs
class TokenizerWrapper:
    """Tiny holder that exposes already-tokenized ids as ``.input_ids``.

    Useful where code expects a tokenizer-output-like object but the ids
    are already available.
    """

    def __init__(self, input_ids):
        # Kept by reference: no copy, conversion or validation happens here.
        self.input_ids = input_ids

# Load and process wikitext2 dataset
def get_wikitext2_sample(nsamples, seed, seqlen, tokenizer):
    """Draw random calibration windows from WikiText-2 and encode its test split.

    The training split is joined with single spaces and tokenized once; then
    ``nsamples`` windows of ``seqlen`` tokens are cut at random offsets.

    Args:
        nsamples: number of windows to draw.
        seed: value passed to ``random.seed`` right before sampling.
        seqlen: length of every window, in tokens.
        tokenizer: callable used as ``tokenizer(text, return_tensors='pt')``
            whose result exposes ``input_ids``.

    Returns:
        ``(trainloader, testenc)``:
        - ``trainloader`` is a list of ``(inp, tar)`` pairs. ``tar`` is a copy
          of ``inp`` with every position but the last set to -100 (presumably
          the loss ignore-index; confirm against the consumer).
        - ``testenc`` is the tokenizer output for the test split joined with
          blank lines.
    """
    train_split = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    test_split = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

    trainenc = tokenizer(" ".join(train_split['text']), return_tensors='pt')
    testenc = tokenizer("\n\n".join(test_split['text']), return_tensors='pt')

    random.seed(seed)
    token_ids = trainenc.input_ids
    # randint is inclusive, so the latest start leaves one spare token at the end.
    last_start = token_ids.shape[1] - seqlen - 1

    windows = []
    for _ in range(nsamples):
        offset = random.randint(0, last_start)
        chunk = token_ids[:, offset:offset + seqlen]
        labels = chunk.clone()
        labels[:, :-1] = -100
        windows.append((chunk, labels))
    return windows, testenc


# Load and process c4 dataset
def get_c4_sample(nsamples, seed, seqlen, tokenizer):
    """Draw random calibration windows from the first C4 (en) training shard.

    For each sample, documents are drawn uniformly at random until one
    tokenizes to strictly more than ``seqlen`` tokens. A random
    ``seqlen``-token window is then cut from that document.

    NOTE(review): the rejection loop only exits on a document longer than
    ``seqlen``; if the shard had none, it would never terminate.

    Args:
        nsamples: number of windows to draw.
        seed: value passed to ``random.seed`` right before sampling.
        seqlen: length of every window, in tokens.
        tokenizer: callable used as ``tokenizer(text, return_tensors='pt')``
            whose result exposes ``input_ids``.

    Returns:
        ``(trainloader, None)``. ``trainloader`` is a list of ``(inp, tar)``
        pairs. ``tar`` is a copy of ``inp`` with every position but the last
        set to -100 (presumably the loss ignore-index; confirm). No
        validation data is produced, so the validation shard is not loaded.
    """
    traindata = load_dataset('allenai/c4', "en", data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')

    print("generating samples")
    random.seed(seed)
    trainloader = []
    for _ in tqdm(range(nsamples)):
        # Rejection-sample documents until one is longer than seqlen tokens.
        while True:
            doc_idx = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[doc_idx]['text'], return_tensors='pt')
            if trainenc.input_ids.shape[1] > seqlen:
                break
        start = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        inp = trainenc.input_ids[:, start:start + seqlen]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    return trainloader, None

def get_c4_sample_for_train(nsamples, seqlen, tokenizer):
    """Cut consecutive, non-overlapping windows from C4 documents, in order.

    Documents of the first C4 (en) training shard are visited in dataset
    order. Each is tokenized and split into back-to-back ``seqlen``-token
    windows; a trailing remainder shorter than ``seqlen`` is dropped.
    Collection stops once ``nsamples`` windows are gathered. No randomness
    is involved.

    Args:
        nsamples: maximum number of windows to return.
        seqlen: length of every window, in tokens.
        tokenizer: callable used as ``tokenizer(text, return_tensors='pt')``
            whose result exposes ``input_ids``.

    Returns:
        ``(trainloader, None)``. ``trainloader`` is a list of ``(inp, tar)``
        pairs, where ``tar`` is a copy of ``inp`` with every position but
        the last set to -100 (presumably the loss ignore-index; confirm).
        The list is empty when ``nsamples <= 0``, and may be shorter than
        ``nsamples`` if the shard runs out of full windows.
    """
    trainloader = []
    if nsamples <= 0:
        return trainloader, None

    traindata = load_dataset('allenai/c4', "en", data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')

    print("generating samples")
    for example in tqdm(traindata):
        input_ids = tokenizer(example['text'], return_tensors='pt').input_ids
        num_tokens = input_ids.shape[1]
        # Every start with start + seqlen <= num_tokens yields a full window;
        # range() is empty when the document is shorter than seqlen.
        for start in range(0, num_tokens - seqlen + 1, seqlen):
            inp = input_ids[:, start:start + seqlen]
            tar = inp.clone()
            tar[:, :-1] = -100
            trainloader.append((inp, tar))
            if len(trainloader) >= nsamples:
                return trainloader, None

    return trainloader, None


def get_ptb_sample(nsamples, seed, seqlen, tokenizer):
    """Draw random calibration windows from Penn Treebank; encode its validation split.

    Both splits come from the local ``./ptb_text_only`` dataset (config
    ``penn_treebank``). Training sentences are joined with spaces and
    tokenized once; then ``nsamples`` random ``seqlen``-token windows are cut.

    Args:
        nsamples: number of windows to draw.
        seed: value passed to ``random.seed`` right before sampling.
        seqlen: length of every window, in tokens.
        tokenizer: callable used as ``tokenizer(text, return_tensors='pt')``
            whose result exposes ``input_ids``.

    Returns:
        ``(trainloader, testenc)``:
        - ``trainloader`` is a list of ``(inp, tar)`` pairs, where ``tar``
          has every position but the last set to -100 (presumably the loss
          ignore-index; confirm).
        - ``testenc`` is the tokenizer output for the *validation* split
          joined with blank lines.
    """
    ptb_train, ptb_valid = (
        load_dataset('./ptb_text_only', 'penn_treebank', split=split_name)
        for split_name in ('train', 'validation')
    )

    trainenc = tokenizer(" ".join(ptb_train['sentence']), return_tensors='pt')
    testenc = tokenizer("\n\n".join(ptb_valid['sentence']), return_tensors='pt')

    random.seed(seed)
    ids = trainenc.input_ids
    highest_start = ids.shape[1] - seqlen - 1

    def _draw_window():
        # One random window plus its masked target copy.
        begin = random.randint(0, highest_start)
        window = ids[:, begin:begin + seqlen]
        target = window.clone()
        target[:, :-1] = -100
        return window, target

    return [_draw_window() for _ in range(nsamples)], testenc


# Load and process the BookCorpus dataset
def get_bookcorpus_sample(nsamples, seed, seqlen, tokenizer):
    """Draw random calibration windows from one local BookCorpus parquet shard.

    For each sample, rows are drawn uniformly at random until one tokenizes
    to strictly more than ``seqlen`` tokens. A random ``seqlen``-token window
    is then cut from it.

    NOTE(review): the rejection loop only exits on a row longer than
    ``seqlen``. If the shard's rows are all short (e.g. single sentences),
    this never terminates. Confirm for the chosen ``seqlen``.

    Args:
        nsamples: number of windows to draw.
        seed: value passed to ``random.seed`` right before sampling.
        seqlen: length of every window, in tokens.
        tokenizer: callable used as ``tokenizer(text, return_tensors='pt')``
            whose result exposes ``input_ids``.

    Returns:
        A list of ``(inp, tar)`` pairs, where ``tar`` has every position but
        the last set to -100. Unlike the other ``get_*_sample`` functions,
        this returns only the list, not a 2-tuple.
    """
    traindata = load_dataset("parquet", data_files={"train": './bookcorpus/train-00000-of-00053-550defad11191c81.parquet'}, split='train')

    random.seed(seed)
    samples = []
    while len(samples) < nsamples:
        row_ids = None
        while row_ids is None:
            row = random.randint(0, len(traindata) - 1)
            encoded = tokenizer(traindata[row]['text'], return_tensors='pt')
            if encoded.input_ids.shape[1] > seqlen:
                row_ids = encoded.input_ids
        begin = random.randint(0, row_ids.shape[1] - seqlen - 1)
        window = row_ids[:, begin:begin + seqlen]
        target = window.clone()
        target[:, :-1] = -100
        samples.append((window, target))

    return samples

def get_wikitext2(seq_len, tokenizer):
    """Load WikiText-2 (raw) train and test splits from local parquet files.

    Args:
        seq_len: unused; accepted so the signature matches ``get_ptb``/``get_c4``.
        tokenizer: unused; accepted for the same reason.

    Returns:
        ``(traindata, testdata)`` as loaded by ``load_dataset("parquet", ...)``.
    """
    def _load_split(split, path):
        # Each split lives in its own parquet file under a machine-specific path.
        return load_dataset("parquet", data_files={split: path}, split=split)

    traindata = _load_split("train", '/mnt/bd/pretraining/mjl_work/data/wikitext-2-raw-v1/train-00000-of-00001-6506f33274247c0c.parquet')
    testdata = _load_split("test", '/mnt/bd/pretraining/mjl_work/data/wikitext-2-raw-v1/test-00000-of-00001-7231805191546d57.parquet')
    return traindata, testdata

def get_ptb(seq_len, tokenizer):
    """Load the Penn Treebank train and validation splits from ``./ptb_text_only``.

    Args:
        seq_len: unused; accepted so the signature matches the other loaders.
        tokenizer: unused; accepted for the same reason.

    Returns:
        ``(traindata, valdata)`` splits of the ``penn_treebank`` config.
    """
    return tuple(
        load_dataset('./ptb_text_only', 'penn_treebank', split=split_name)
        for split_name in ('train', 'validation')
    )

def get_c4(seq_len, tokenizer, num_val_samples=5000):
    """Load local C4 train/validation JSON shards, capping the validation rows.

    Args:
        seq_len: unused; accepted so the signature matches the other loaders.
        tokenizer: unused; accepted for the same reason.
        num_val_samples: keep at most this many leading validation rows
            (default 5000).

    Returns:
        ``(traindata, valdata)``, where ``valdata`` holds the first
        ``min(num_val_samples, len(valdata))`` validation rows.
    """
    traindata = load_dataset('json', data_files={'train': './c4/c4-train.00000-of-01024.json'}, split='train')
    valdata = load_dataset('json', data_files={'validation': './c4/c4-validation.00000-of-00008.json'}, split='validation')
    # Clamp so a smaller (e.g. truncated) validation file does not make select() fail.
    valdata = valdata.select(range(min(num_val_samples, len(valdata))))
    return traindata, valdata

class IndexDataset(Dataset):
    """Map-style dataset that serves items of ``tensors`` by position.

    Any object supporting ``len()`` and integer indexing works (e.g. a
    stacked tensor, where each item is one row).
    """

    def __init__(self, tensors):
        self.tensors = tensors

    def __len__(self):
        return len(self.tensors)

    def __getitem__(self, index):
        # Delegates directly to the wrapped container's indexing.
        return self.tensors[index]

def process_data(samples, tokenizer, seq_len, field_name):
    """Tokenize a text split as one stream and cut it into fixed-length rows.

    The entries of ``samples[field_name]`` are joined with blank lines and
    tokenized once. The ids are split into ``len // seq_len`` consecutive,
    non-overlapping rows of ``seq_len`` tokens; trailing tokens that do not
    fill a whole row are dropped.

    Args:
        samples: indexable by ``field_name``, yielding a sequence of strings
            (e.g. a dataset split).
        tokenizer: callable used as ``tokenizer(text, return_tensors='pt')``
            whose ``input_ids`` presumably has a leading batch dim of 1
            (row 0 is taken).
        seq_len: number of tokens per row.
        field_name: name of the text column.

    Returns:
        ``IndexDataset`` over a ``(num_rows, seq_len)`` tensor. ``num_rows``
        is 0 (an empty tensor) when the text has fewer than ``seq_len`` tokens.
    """
    test_ids = tokenizer("\n\n".join(samples[field_name]), return_tensors='pt').input_ids[0]
    nsamples = test_ids.numel() // seq_len
    # A single reshape replaces per-row slicing + torch.stack. Unlike
    # torch.stack on an empty list, it also works when nsamples == 0.
    test_ids_batch = test_ids[: nsamples * seq_len].reshape(nsamples, seq_len)
    return IndexDataset(tensors=test_ids_batch)

def get_loaders_sample(name, nsamples=128, seed=0, seqlen=2048, tokenizer=None):
    """Dispatch to the random-window calibration sampler that matches ``name``.

    ``name`` is matched by substring in the order wikitext2, c4, bookcorpus,
    ptb; the first match wins.

    Args:
        name: dataset identifier containing one of the keys above.
        nsamples: number of windows to draw.
        seed: seed forwarded to the sampler.
        seqlen: window length, in tokens.
        tokenizer: tokenizer forwarded to the sampler.

    Returns:
        Whatever the selected ``get_*_sample`` function returns. The
        bookcorpus sampler returns only a list of ``(inp, tar)`` pairs;
        the others return a 2-tuple.

    Raises:
        ValueError: if ``name`` matches none of the supported datasets.
    """
    if 'wikitext2' in name:
        return get_wikitext2_sample(nsamples, seed, seqlen, tokenizer)
    elif "c4" in name:
        return get_c4_sample(nsamples, seed, seqlen, tokenizer)
    elif "bookcorpus" in name:
        return get_bookcorpus_sample(nsamples, seed, seqlen, tokenizer)
    elif "ptb" in name:
        return get_ptb_sample(nsamples, seed, seqlen, tokenizer)
    raise ValueError(
        f"Unknown calibration dataset {name!r}; expected a name containing "
        "'wikitext2', 'c4', 'bookcorpus' or 'ptb'."
    )
    
# Function to select the appropriate loader based on dataset name
def get_loaders(name, tokenizer, seq_len=2048, batch_size=8):
    """Load a dataset's train split and an evaluation DataLoader over its test split.

    Args:
        name: dataset identifier containing 'c4', 'ptb' or 'wikitext2'.
        tokenizer: tokenizer used to encode the evaluation text.
        seq_len: tokens per evaluation row.
        batch_size: DataLoader batch size.

    Returns:
        ``(train_data, test_loader)``:
        - ``train_data`` is the raw train split from the matching ``get_*``
          loader.
        - ``test_loader`` is a non-shuffled DataLoader over the evaluation
          split, cut into ``seq_len``-token rows by ``process_data``.

    Raises:
        ValueError: if ``name`` contains none of the supported keys.
    """
    # Checked in this order so that, for a name matching several keys, c4
    # takes precedence over ptb, which takes precedence over wikitext2.
    if 'c4' in name:
        train_data, test_data = get_c4(seq_len, tokenizer)
        field_name = 'text'
    elif 'ptb' in name:
        train_data, test_data = get_ptb(seq_len, tokenizer)
        field_name = 'sentence'
    elif 'wikitext2' in name:
        train_data, test_data = get_wikitext2(seq_len, tokenizer)
        field_name = 'text'
    else:
        raise ValueError(
            f"Unknown evaluation dataset {name!r}; expected a name containing "
            "'wikitext2', 'ptb' or 'c4'."
        )

    test_dataset = process_data(test_data, tokenizer, seq_len, field_name)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_data, test_loader

