import torch
import numpy as np
import jax.numpy as jnp
from torch.utils.data import Dataset

class JAXTextDataset(Dataset):
    """Map-style dataset of random fixed-length next-token-prediction windows.

    Each ``__getitem__`` call ignores ``idx`` and draws a fresh random window
    of ``seq_len + 1`` consecutive tokens from ``data`` (via NumPy's global
    RNG). The window is split into ``input_ids`` (first ``seq_len`` tokens) and
    ``labels`` (last ``seq_len`` tokens), so ``labels[i]`` is the token that
    follows ``input_ids[i]``.

    Args:
        data: Sliceable token sequence supporting ``len()`` (e.g. a list or
            1-D array of integer ids); must hold at least ``seq_len + 1``
            tokens.
        seq_len: Length of each returned ``input_ids`` / ``labels`` sequence;
            must be >= 1.

    Raises:
        ValueError: If ``seq_len`` < 1, or if ``data`` is too short to contain
            a single window of ``seq_len + 1`` tokens.
    """

    def __init__(self, data, seq_len):
        # Fail fast at construction instead of with an opaque randint error
        # (or a ZeroDivisionError in __len__) deep inside a DataLoader.
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        if len(data) < seq_len + 1:
            raise ValueError(
                f"data must contain at least seq_len + 1 = {seq_len + 1} "
                f"tokens, got {len(data)}"
            )
        self.data = data
        self.seq_len = seq_len
        # Largest valid window start: the slice [start, start + seq_len + 1)
        # must stay within data.
        self.max_start = len(data) - seq_len - 1

    def __len__(self):
        """Return the nominal epoch size (number of samples drawn per epoch).

        Sampling is random, so this is only an estimate: one sample per
        non-overlapping ``seq_len`` chunk of ``data``.
        """
        return len(self.data) // self.seq_len

    def __getitem__(self, idx):
        """Return a randomly sampled window; ``idx`` is intentionally ignored.

        Returns:
            dict with ``'input_ids'``, ``'attention_mask'`` (all ones) and
            ``'labels'``, each a ``jnp.int32`` array of length ``seq_len``.
        """
        # NOTE(review): uses NumPy's *global* RNG. Recent PyTorch DataLoaders
        # reseed NumPy per worker; on older versions workers could produce
        # identical samples -- confirm for the torch version in use.
        start = np.random.randint(0, self.max_start + 1)  # upper bound exclusive
        seq = self.data[start : start + self.seq_len + 1]
        # Shift by one token so each label is the next token of its input.
        input_ids = jnp.array(seq[:-1], dtype=jnp.int32)
        labels = jnp.array(seq[1:], dtype=jnp.int32)
        attention_mask = jnp.ones_like(input_ids)
        # NOTE(review): torch's default_collate presumably does not recognise
        # JAX arrays; a custom collate_fn may be required -- confirm.
        return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}
