import random
import numpy as np
import torch

from datasets import load_dataset
from torch.utils.data.dataset import Dataset
from transformers import AutoTokenizer, PreTrainedTokenizer

def get_wikitext2(seq_len, tokenizer):
    """Load the raw WikiText-2 train and test splits.

    Args:
        seq_len: Unused by this function.
        tokenizer: Unused by this function.

    Returns:
        A ``(train, test)`` tuple holding whatever ``load_dataset`` returns
        for the ``'train'`` and ``'test'`` splits, in that order.
    """
    # The generator is consumed left to right, so 'train' is still fetched
    # before 'test'.
    train_split, test_split = (
        load_dataset('Salesforce/wikitext', 'wikitext-2-raw-v1', split=split_name)
        for split_name in ('train', 'test')
    )
    return train_split, test_split

def get_ptb(seq_len, tokenizer):
    """Load the Penn Treebank train and validation splits.

    Args:
        seq_len: Unused by this function.
        tokenizer: Unused by this function.

    Returns:
        A ``(train, validation)`` tuple holding whatever ``load_dataset``
        returns for the ``'train'`` and ``'validation'`` splits, in that
        order.
    """
    repo_id = 'ptb-text-only/ptb_text_only'
    config_name = 'penn_treebank'
    train_split = load_dataset(repo_id, config_name, split='train')
    validation_split = load_dataset(repo_id, config_name, split='validation')
    return train_split, validation_split

class IndexDataset(Dataset):
    """Dataset that exposes an indexable container one item at a time.

    The container is stored unchanged on ``self.tensors``; length and item
    lookup are delegated straight to it.
    """

    def __init__(self, tensors):
        # Kept by reference: no copy or conversion is made.
        self.tensors = tensors

    def __len__(self):
        """Return the number of items in the wrapped container."""
        size = len(self.tensors)
        return size

    def __getitem__(self, index):
        """Return the item stored at ``index`` in the wrapped container."""
        item = self.tensors[index]
        return item

def process_data(OPT_ornot, samples, tokenizer, seq_len, field_name):
    """Tokenize a text column and cut it into fixed-length token windows.

    Every entry of ``samples[field_name]`` is joined with two newlines,
    tokenized as a single string, and split into consecutive, non-overlapping
    windows of ``seq_len`` tokens. Trailing tokens that do not fill a complete
    window are dropped.

    Args:
        OPT_ornot: Selects the tokenizer calling convention. When truthy, the
            tokenizer is called directly and its ``"input_ids"`` entry is
            used; otherwise ``tokenizer.encode(text, bos=False, eos=False)``
            is called.
        samples: Object indexable by ``field_name`` that yields an iterable
            of strings.
        tokenizer: Tokenizer object matching the convention selected by
            ``OPT_ornot``.
        seq_len: Number of tokens per window; must be positive.
        field_name: Key of the text column inside ``samples``.

    Returns:
        An ``IndexDataset`` wrapping a tensor of shape
        ``(num_windows, seq_len)``. ``num_windows`` is zero when the
        tokenized text is shorter than ``seq_len``.

    Raises:
        ValueError: If ``seq_len`` is not positive.
    """
    # Validate before the (potentially expensive) tokenization so that a bad
    # window size fails fast with a clear message instead of a
    # ZeroDivisionError further down.
    if seq_len <= 0:
        raise ValueError(f"seq_len must be a positive integer, got {seq_len!r}")

    text = "\n\n".join(samples[field_name])
    if OPT_ornot:
        token_ids = torch.tensor(tokenizer(text)["input_ids"])
    else:
        token_ids = torch.tensor(tokenizer.encode(text, bos=False, eos=False))

    # NOTE(review): assumes the tokenizer returns a flat list of ids for a
    # single input string -- confirm for any custom tokenizer.
    nsamples = token_ids.numel() // seq_len

    # Drop the incomplete tail, then fold the rest into rows. Unlike stacking
    # a list of slices, this also works when nsamples == 0 and yields an empty
    # (0, seq_len) tensor.
    windows = token_ids[: nsamples * seq_len].reshape(nsamples, seq_len)
    return IndexDataset(tensors=windows)
       

def get_loaders(OPT_ornot, name, tokenizer, seq_len=2048, batch_size=8):
    """Build train and test ``DataLoader`` objects for a named text corpus.

    Args:
        OPT_ornot: Forwarded to ``process_data`` to select the tokenizer
            calling convention.
        name: Dataset identifier. It must contain ``'ptb'`` or
            ``'wikitext2'`` as a substring; if it contains both, PTB is used.
        tokenizer: Tokenizer forwarded to ``process_data``.
        seq_len: Window length in tokens used for the test set, and for the
            PTB training set.
        batch_size: Batch size for both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple. Neither loader shuffles.

    Raises:
        ValueError: If ``name`` matches no supported dataset.
    """
    # PTB is checked first so that a name containing both substrings selects
    # PTB without loading and tokenizing WikiText-2 as well.
    if 'ptb' in name:
        train_data, test_data = get_ptb(seq_len, tokenizer)
        field_name = 'sentence'
        train_seq_len = seq_len
    elif 'wikitext2' in name:
        train_data, test_data = get_wikitext2(seq_len, tokenizer)
        field_name = 'text'
        # NOTE(review): the WikiText-2 training split uses a fixed window of
        # 128 tokens rather than seq_len -- confirm this is intentional.
        train_seq_len = 128
    else:
        raise ValueError(
            f"Unsupported dataset name {name!r}; expected it to contain "
            "'wikitext2' or 'ptb'"
        )

    test_dataset = process_data(OPT_ornot, test_data, tokenizer, seq_len, field_name)
    train_dataset = process_data(OPT_ornot, train_data, tokenizer, train_seq_len, field_name)

    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader