import torch
from torch.utils.data import Dataset
import numpy as np
import jax.lax as lax


class DatasetTorch(Dataset):
    """Map-style torch ``Dataset`` over paired inputs and labels held in memory.

    Both arguments are converted once (``np.array`` then ``torch.from_numpy``)
    into tensors that are indexed along their first axis.

    Args:
        inputs: Array-like of samples; the first axis indexes samples.
        labels: Array-like of labels; the first axis indexes samples and must
            have the same length as ``inputs``. May be 1-D (e.g. one class
            index per sample) or have trailing dimensions.

    Raises:
        ValueError: If ``inputs`` and ``labels`` have different lengths.
    """

    def __init__(self, inputs, labels):
        self.inputs = torch.from_numpy(np.array(inputs))
        self.labels = torch.from_numpy(np.array(labels))
        # Fail fast: a mismatch would otherwise surface late as an IndexError
        # or leave extra labels silently unused.
        if len(self.inputs) != len(self.labels):
            raise ValueError(
                "inputs and labels must have the same length, got "
                f"{len(self.inputs)} and {len(self.labels)}"
            )

    def __len__(self):
        """Return the number of samples (size of the first input axis)."""
        return self.inputs.shape[0]

    def __getitem__(self, idx):
        """Return the ``(input, label)`` pair(s) selected by ``idx``.

        Plain ``[idx]`` indexing is used instead of ``[idx, :]`` so that 1-D
        inputs/labels work too; for arrays with two or more dimensions the two
        forms are equivalent.
        """
        input_batch = self.inputs[idx]
        label_batch = self.labels[idx]

        return input_batch, label_batch


class DataLoader:
    """Minimal batcher that cuts contiguous, fixed-size batches with ``jax.lax``.

    Batches are taken in stored order (no shuffling). ``self.len`` counts only
    full batches, ``n // batch_size``; the trailing ``n % batch_size`` samples
    get no batch index of their own.

    Args:
        inputs: Array of samples with the sample axis first (typically
            ``(n, dim)``).
        targets: Array of targets with the sample axis first, e.g. ``(n,)``,
            ``(n, 1)`` or ``(n, k)``.
        batch_size: Number of samples per batch; must be positive.

    Raises:
        ValueError: If ``batch_size`` is not positive, or if ``inputs`` and
            ``targets`` have different numbers of samples.
    """

    def __init__(self, inputs, targets, batch_size):
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if inputs.shape[0] != targets.shape[0]:
            raise ValueError(
                "inputs and targets must have the same number of samples, got "
                f"{inputs.shape[0]} and {targets.shape[0]}"
            )

        self.inputs = inputs
        self.targets = targets
        self.bs = batch_size

        self.n = inputs.shape[0]
        # Feature width, kept for backward compatibility; 1-D inputs are
        # treated as having a single feature.
        self.dim = inputs.shape[1] if inputs.ndim > 1 else 1
        self.len = self.n // self.bs

    def get_batch(self, i):
        """Return the ``i``-th ``(batch_inputs, batch_targets)`` pair.

        Both arrays are sliced along axis 0 only, so every trailing dimension
        is kept whole (earlier versions truncated targets to one column).

        Note: ``i`` is not range-checked. ``jax.lax.dynamic_slice`` clamps the
        start index so the slice stays in bounds, so an ``i`` past the last
        full batch yields the final ``batch_size`` rows. Using ``lax`` slicing
        presumably lets ``i`` be a traced value under ``jax.jit`` — confirm
        against callers.
        """
        start = i * self.bs
        batch_inputs = lax.dynamic_slice_in_dim(self.inputs, start, self.bs, axis=0)
        batch_targets = lax.dynamic_slice_in_dim(self.targets, start, self.bs, axis=0)

        return batch_inputs, batch_targets
