import torch
import datasets
import transformers


class TinyStoriesDataset(torch.utils.data.Dataset):
    """Map-style dataset over the ``roneneldan/TinyStories`` Hugging Face dataset.

    Each item is ``(text, random_token)``. ``random_token`` is a token index
    sampled from the second half of the (length-capped) tokenized story, or
    ``None`` when no tokenizer is given.
    """

    # cfg.datasets attribute that sets how many batches of each split to keep.
    _SPLIT_ITERS = {
        "train": "finetune_iters",
        "validation": "val_iters",
        "test": "test_iters",
    }

    def __init__(self, split, tokenizer=None, cfg=None):
        """Load ``split`` of TinyStories and optionally truncate it.

        Args:
            split: Dataset split name. For "train", "validation" and "test",
                the split is truncated to
                ``cfg.learning.batch_size * cfg.datasets.<split>_iters`` rows
                and materialized as a list of ``{'text': ...}`` dicts. Any
                other split is kept as the raw Hugging Face split object.
            tokenizer: Optional callable mapping a string to either a list of
                token ids or a ``transformers`` ``BatchEncoding``.
            cfg: Config object. It must provide ``learning.batch_size`` and
                the relevant ``datasets.*_iters`` value for the known splits,
                and ``datasets.max_length`` when a tokenizer is used.
        """
        self.dataset = datasets.load_dataset("roneneldan/TinyStories", trust_remote_code=True)
        self.split = split
        self.dataset = self.dataset[split]
        iters_attr = self._SPLIT_ITERS.get(split)
        if iters_attr is not None:
            n_rows = int(cfg.learning.batch_size * getattr(cfg.datasets, iters_attr))
            # Slicing a HF Dataset returns a dict of columns; rebuild row dicts.
            self.dataset = [{'text': text} for text in self.dataset[:n_rows]['text']]
        self.tokenizer = tokenizer
        self.cfg = cfg

    def __len__(self):
        """Return the number of rows kept for this split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(text, random_token)`` for row ``idx``.

        Rows whose text has at most 2 characters are skipped. The next row
        is used instead, wrapping to the start of the dataset if needed.

        Raises:
            IndexError: if ``idx`` itself is out of range.
            ValueError: if every row is too short, or the tokenizer returns
                neither a list nor a ``BatchEncoding``.
        """
        n_rows = len(self.dataset)
        # Index the requested row unwrapped so out-of-range idx still raises IndexError.
        sample = self.dataset[idx]['text']
        attempts = 1
        while len(sample) <= 2:  # skip empty / near-empty samples
            if attempts >= n_rows:
                raise ValueError("No sample with more than 2 characters found in dataset.")
            idx = (idx + 1) % n_rows  # wrap instead of running off the end
            sample = self.dataset[idx]['text']
            attempts += 1

        if not self.tokenizer:
            return sample, None

        tokens = self.tokenizer(sample)
        if isinstance(tokens, list):
            pass  # e.g. GPT-2 tokenizer for NanoGPT already returns token ids
        elif isinstance(tokens, transformers.tokenization_utils_base.BatchEncoding):
            # For transformers tokenizers
            tokens = tokens['input_ids']
        else:
            raise ValueError("Tokenizer output must be a list or dictionary.")

        max_length = self.cfg.datasets.max_length
        low = int(min(max_length, len(tokens)) / 2)
        # last token is EOS, don't want to predict that
        high = min(len(tokens) - 2, max_length - 2)
        # torch.randint needs low < high; short sequences would otherwise crash.
        high = max(high, low + 1)
        random_token = torch.randint(low, high, (1,)).item()

        return sample, random_token