import torch
from nltk.tokenize import TreebankWordTokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import PennTreebank
from torch.utils.data import DataLoader, Dataset

# Literal out-of-vocabulary marker used in the Penn Treebank text files; also
# registered as the vocabulary's unknown-token special.
_PTB_UNK = "<unk>"


def _tokenize_keep_unk(tokenizer, text):
    """Tokenize ``text`` with ``tokenizer`` while keeping ``<unk>`` markers whole.

    TreebankWordTokenizer splits ``<`` and ``>`` off as separate tokens, which
    would turn each ``<unk>`` marker into ``'<', 'unk', '>'``. Splitting on the
    marker first and re-inserting it keeps it as a single token, so it maps to
    the vocabulary's ``<unk>`` special.

    Args:
        tokenizer: Any object with a ``tokenize(str) -> list[str]`` method.
        text: One line of raw text.

    Returns:
        The list of string tokens for ``text``.
    """
    tokens = []
    for position, chunk in enumerate(text.split(_PTB_UNK)):
        if position:
            # Every chunk after the first was preceded by a marker.
            tokens.append(_PTB_UNK)
        tokens.extend(tokenizer.tokenize(chunk))
    return tokens


class RandomSampleDataset(Dataset):
    """Sliding-window next-token dataset over a flat 1-D token sequence.

    Item ``idx`` is ``(tokens[idx:idx + seq_len], tokens[idx + 1:idx + seq_len + 1])``:
    an input window plus the same window shifted one position right (the
    next-token targets). Every start offset is an item, so windows overlap.
    Any random sampling comes from the DataLoader (e.g. ``shuffle=True``),
    not from this class.

    Defined at module level so instances can be pickled, which DataLoader
    requires when it starts worker processes with the ``spawn`` method.
    """

    def __init__(self, tokens, seq_len):
        """
        Args:
            tokens: 1-D sequence of token ids supporting ``len`` and slicing
                (a ``torch.long`` tensor when built by ``get_PennTree_dataset``).
            seq_len: Window length; must be at least 1.

        Raises:
            ValueError: If ``seq_len`` is less than 1.
        """
        if seq_len < 1:
            raise ValueError(f"seq_len must be a positive integer, got {seq_len!r}")
        self.tokens = tokens
        self.seq_len = seq_len

    def __len__(self):
        # Each item needs seq_len + 1 tokens (inputs plus one extra target).
        # Clamp at 0 so a too-short corpus yields an empty dataset instead of
        # a negative length, which len() rejects with ValueError.
        return max(0, len(self.tokens) - self.seq_len)

    def __getitem__(self, idx):
        input_seq = self.tokens[idx:idx + self.seq_len]
        target_seq = self.tokens[idx + 1:idx + self.seq_len + 1]
        return input_seq, target_seq


def get_PennTree_dataset(seq_len):
    """Build next-token language-modeling datasets from the Penn Treebank.

    Each split is tokenized line by line with NLTK's TreebankWordTokenizer
    (``<unk>`` markers kept as single tokens). The tokens are mapped to vocab
    ids and concatenated into one 1-D ``torch.long`` tensor per split, which is
    wrapped in a sliding-window ``RandomSampleDataset``. No separator token is
    inserted between lines.

    Args:
        seq_len: Length of each input/target window; must be at least 1.

    Returns:
        ``(train_dataset, test_dataset, vocab_size)``. The vocabulary is built
        from the training split only, with ``"<unk>"`` and ``"<pad>"`` as
        specials. Test tokens unseen in training map to ``"<unk>"``.

    Raises:
        ValueError: If ``seq_len`` is less than 1 (checked before any data
            is loaded).
    """
    if seq_len < 1:
        raise ValueError(f"seq_len must be a positive integer, got {seq_len!r}")

    tokenizer = TreebankWordTokenizer()

    # Tokenize the training split once. The same per-line token lists feed
    # both vocabulary construction and numericalization.
    train_lines = [
        _tokenize_keep_unk(tokenizer, text) for text in PennTreebank(split='train')
    ]
    vocab = build_vocab_from_iterator(train_lines, specials=[_PTB_UNK, "<pad>"])
    # Tokens not seen in training fall back to the <unk> index.
    vocab.set_default_index(vocab[_PTB_UNK])

    def to_ids(token_lines):
        """Flatten an iterable of per-line token lists into one 1-D id tensor."""
        ids = []
        for tokens in token_lines:
            ids.extend(vocab(tokens))
        return torch.tensor(ids, dtype=torch.long)

    train_tokens = to_ids(train_lines)
    test_tokens = to_ids(
        _tokenize_keep_unk(tokenizer, text) for text in PennTreebank(split='test')
    )

    train_dataset = RandomSampleDataset(train_tokens, seq_len)
    test_dataset = RandomSampleDataset(test_tokens, seq_len)

    return train_dataset, test_dataset, len(vocab)

