from typing import Tuple
import torch
from torch.utils.data import Dataset
from datasets import load_dataset


class RandomTokens(Dataset):
    """Map-style dataset that yields sequences of random token ids.

    Token ids are drawn uniformly (with replacement) from
    ``range(tokenizer.vocab_size)`` minus the tokenizer's EOS, PAD and UNK
    ids, where those attributes exist and are not ``None``.

    Samples are drawn from a single generator owned by the dataset, so the
    returned tokens depend on how many items were requested before, not on
    the index passed to ``__getitem__``.

    Args:
        tokenizer: Object exposing ``vocab_size`` and, optionally,
            ``eos_token_id`` / ``pad_token_id`` / ``unk_token_id``.
        seq_len: Number of tokens in every generated sequence.
        size: Value reported by ``len()``.
        seed: Optional seed; when given, the generator is seeded with
            ``int(seed)``.
        **kwargs: Accepted for call-site compatibility and ignored.

    Raises:
        ValueError: If no token id is left to sample from after the special
            ids are removed.
    """

    _SPECIAL_ID_ATTRS = ("eos_token_id", "pad_token_id", "unk_token_id")

    def __init__(self, tokenizer, seq_len, size=1_000, seed=None, **kwargs):
        super().__init__()
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.size = size

        # Tokenizers may lack some of these attributes or set them to None.
        candidates = (
            getattr(tokenizer, attr, None) for attr in self._SPECIAL_ID_ATTRS
        )
        special_ids = {token_id for token_id in candidates if token_id is not None}

        self.allowed_ids = torch.tensor(
            [i for i in range(tokenizer.vocab_size) if i not in special_ids],
            dtype=torch.long,
        )
        if self.allowed_ids.numel() == 0:
            # Fail here rather than with an opaque torch.randint error later.
            raise ValueError(
                "tokenizer has no token ids left to sample from after "
                "excluding special tokens"
            )

        self._rand = torch.Generator()
        if seed is not None:
            self._rand.manual_seed(int(seed))

    def __len__(self):
        """Return the nominal number of samples in the dataset."""
        return self.size

    def __getitem__(self, index) -> Tuple[torch.Tensor, ...]:
        """Return a freshly sampled ``(input_ids, attention_mask)`` pair.

        Both tensors are 1-D ``torch.long`` tensors of length ``seq_len``;
        the attention mask is all ones. ``index`` is only used for bounds
        checking; it does not select the sample.

        Raises:
            IndexError: If ``index`` is an ``int`` outside
                ``[-size, size)``.
        """
        # Without this check plain iteration (``for x in ds`` / ``list(ds)``)
        # would never stop, because it relies on IndexError to terminate.
        if isinstance(index, int) and not -self.size <= index < self.size:
            raise IndexError(
                f"index {index} is out of range for dataset of size {self.size}"
            )

        idxs = torch.randint(
            low=0,
            high=len(self.allowed_ids),
            size=(self.seq_len,),
            generator=self._rand,
            dtype=torch.long,
        )
        input_ids = self.allowed_ids[idxs]
        attention_mask = torch.ones(self.seq_len, dtype=torch.long)
        return input_ids, attention_mask


def get_llm_dataset(hub_path, split="train", **kwargs):
    """Load a dataset split via ``load_dataset`` from a combined id string.

    ``hub_path`` is split on its last ``/``: the part before it is passed as
    the dataset path and the part after it as the configuration name, e.g.
    ``"path/name"`` or ``"org/path/name"``. A string without ``/`` is used
    as the path with no configuration name.

    Args:
        hub_path: Combined ``"<path>/<name>"`` identifier.
        split: Split forwarded to ``load_dataset``.
        **kwargs: Extra keyword arguments forwarded to ``load_dataset``.

    Returns:
        Whatever ``load_dataset`` returns for the requested split.
    """
    # Splitting on the last slash keeps "path/name" inputs unchanged while
    # no longer failing to unpack when there are zero or several slashes.
    path, sep, name = hub_path.rpartition("/")
    if not sep:
        path, name = hub_path, None
    return load_dataset(path, name, split=split, **kwargs)
