import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pickle




class CustomDataset(Dataset):
    """Map-style dataset backed by a single in-memory tensor.

    Args:
        data: Array-like input (nested sequence, NumPy array, or tensor).
            It is converted to a tensor of the default floating dtype.
        transform: Optional callable applied to each sample when it is
            fetched. ``None`` means samples are returned unchanged.
    """

    def __init__(self, data, transform=None):
        # torch.as_tensor always treats its argument as *values*. The legacy
        # torch.Tensor(...) constructor interprets an integer argument as a
        # size and returns uninitialised memory instead.
        self.data = torch.as_tensor(data, dtype=torch.get_default_dtype())
        self.transform = transform

    def __len__(self):
        """Return the number of samples (size of the first dimension)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at ``idx``, after ``transform`` if one is set."""
        sample = self.data[idx]
        # Compare against None explicitly: a callable may legitimately be
        # falsy (e.g. it defines __len__ or __bool__) and must still run.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

class ToFloat16:
    """Transform that converts a sample to ``torch.float16``.

    Per the original author's note, the conversion is there to avoid a
    type error.
    """

    def __call__(self, sample):
        half_precision = torch.float16
        return sample.to(half_precision)
    

def load_datasets(path='data/train_test_slices/train_slices.pkl'):
    """Load pickled data from ``path`` and wrap it in a ``CustomDataset``.

    Args:
        path: Location of the pickle file. The default is the training
            slices file, given as a relative path.

    Returns:
        A ``CustomDataset`` over the unpickled data, constructed with a
        ``ToFloat16`` transform.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point this at files from a trusted source.
    with open(path, 'rb') as f:
        data = pickle.load(f)
    return CustomDataset(data, transform=ToFloat16())


def create_dataloader(batch_size, dataset):
    """Wrap ``dataset`` in a shuffling ``DataLoader``.

    Args:
        batch_size: Value passed to the loader as ``batch_size``.
        dataset: Dataset the loader draws samples from.

    Returns:
        A ``DataLoader`` over ``dataset`` created with ``shuffle=True``.
    """
    loader_options = {"batch_size": batch_size, "shuffle": True}
    return DataLoader(dataset, **loader_options)
