import os
import torch
import torch.utils.data as data
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
import datasets
import transformers
from tqdm import tqdm

class TextDataset(data.Dataset):
    """Fixed-length, non-overlapping windows over a 1-D tensor of token ids.

    The corpus is cut into ``len(text) // max_length`` consecutive chunks of
    ``max_length`` tokens; trailing tokens that do not fill a whole chunk are
    dropped. Each item is wrapped with the EOS token on both sides, so every
    ``input_ids`` tensor holds ``max_length + 2`` elements.

    Args:
        text: 1-D tensor of token ids (the whole corpus, concatenated).
        eos_token: Token id placed before and after each chunk.
        max_length: Number of corpus tokens per item; must be positive.

    Raises:
        ValueError: If ``max_length`` is not positive.
    """

    def __init__(self, text, eos_token, max_length=256):
        if max_length <= 0:
            raise ValueError(f'max_length must be positive, got {max_length}')
        self.text = text
        self.max_length = max_length
        self.eos_token = torch.tensor([eos_token])

    def __len__(self):
        """Return the number of complete ``max_length``-token chunks."""
        # Floor division keeps the computation in exact integer arithmetic.
        return len(self.text) // self.max_length

    def __getitem__(self, idx):
        """Return chunk ``idx`` as ``{'input_ids': tensor}`` with EOS on both ends.

        Negative indices count from the end, as for a list.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        num_chunks = len(self)
        if idx < 0:
            # Rebind instead of ``+=`` so a tensor index is never mutated in place.
            idx = idx + num_chunks
        # Raising IndexError is what terminates plain iteration (``for x in ds``);
        # without it, out-of-range slices are empty and iteration never stops.
        if not 0 <= idx < num_chunks:
            raise IndexError(
                f'index out of range for TextDataset with {num_chunks} chunks'
            )
        start = idx * self.max_length
        chunk = self.text[start : start + self.max_length]
        return {'input_ids': torch.concat([self.eos_token, chunk, self.eos_token])}

def process_wikitext(tokenizer, text):
    """Tokenize every non-empty entry and concatenate the ids into one tensor.

    Args:
        tokenizer: Object whose ``encode(entry)`` returns a sequence of token
            ids for one string.
        text: Iterable of strings; empty strings are skipped.

    Returns:
        A 1-D ``torch.long`` tensor holding the token ids of all encoded
        entries in order. The tensor is empty when there is nothing to encode.
    """
    # Build each row as int64 up front so that an entry encoding to zero
    # tokens cannot introduce a float tensor into the concatenation.
    rows = [
        torch.tensor(tokenizer.encode(entry), dtype=torch.long)
        for entry in tqdm(text)
        if entry
    ]
    if not rows:
        # torch.cat raises on an empty list; return an empty corpus instead.
        return torch.empty(0, dtype=torch.long)
    return torch.cat(rows)
    
def make_dataset(max_length=256):
    """Build ``TextDataset`` objects for the WikiText-103 train/validation/test splits.

    Each split's tokenized corpus is cached as ``wikitext-103/wikitext_{split}.pt``.
    When a cache file exists it is loaded with ``torch.load``; otherwise the
    split is tokenized with the ``gpt2`` tokenizer and saved there. The raw
    dataset and the tokenizer are loaded only when at least one split is
    missing from the cache.

    Args:
        max_length: Number of corpus tokens per dataset item.

    Returns:
        A list ``[train, validation, test]`` of ``TextDataset`` instances.
    """
    cache_dir = 'wikitext-103'
    # Id handed to TextDataset as the boundary token (GPT-2's end-of-text id).
    eos_token_id = 50256
    splits = ['train', 'validation', 'test']

    # Loaded lazily: both are expensive and unnecessary when every split is cached.
    wikitext_data = None
    tokenizer = None

    wikitext_datasets = []
    for split in splits:
        cache_path = f'{cache_dir}/wikitext_{split}.pt'
        if os.path.exists(cache_path):
            corpus_split = torch.load(cache_path)
        else:
            if wikitext_data is None:
                wikitext_data = datasets.load_dataset('wikitext', 'wikitext-103-raw-v1')
                tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
            corpus_split = process_wikitext(tokenizer, wikitext_data[split]['text'])
            # exist_ok avoids the check-then-create race on the cache directory.
            os.makedirs(cache_dir, exist_ok=True)
            torch.save(corpus_split, cache_path)
        wikitext_datasets.append(TextDataset(corpus_split, eos_token_id, max_length))

    return wikitext_datasets

def make_dataloaders(batch_size, num_workers = 4, max_length = 256):
    """Create the train, validation and test ``DataLoader`` objects.

    Args:
        batch_size: Batch size used by all three loaders.
        num_workers: Worker count used by all three loaders.
        max_length: Forwarded to ``make_dataset``.

    Returns:
        A tuple ``(train_loader, valid_loader, test_loader)``. Only the
        training loader shuffles.
    """
    train_set, valid_set, test_set = make_dataset(max_length)
    # Settings shared by every loader.
    shared = {'batch_size': batch_size, 'num_workers': num_workers}
    train_loader = data.DataLoader(train_set, shuffle=True, **shared)
    valid_loader = data.DataLoader(valid_set, shuffle=False, **shared)
    test_loader = data.DataLoader(test_set, **shared)
    return train_loader, valid_loader, test_loader

def _demo():
    """Smoke test: print one training batch and the decoded text of its first row."""
    # make_dataloaders calls make_dataset itself, which builds or loads the
    # cached splits, so a separate make_dataset() call beforehand would only
    # load every split a second time.
    train_loader, _, _ = make_dataloaders(3, num_workers=4, max_length=50)
    tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
    for batch in train_loader:
        input_ids = batch['input_ids']
        print(input_ids)
        print(tokenizer.decode(input_ids[0]))
        # One batch is enough for a visual check.
        break


if __name__ == '__main__':
    _demo()