from torch.utils.data import Dataset
from torch.utils.data import SubsetRandomSampler, DataLoader
import torch
import numpy as np


def torch_train_val_split(dataset, batch_size, val_size=.2, shuffle=True, seed=None,
                          num_workers=2):
    """Split ``dataset`` into training and validation DataLoaders.

    After optional shuffling, the first ``floor(val_size * len(dataset))``
    indices form the validation subset and the remaining indices form the
    training subset. Both loaders read from the same ``dataset`` object through
    a ``SubsetRandomSampler``, so the order within each subset is randomized on
    every pass regardless of ``shuffle``.

    Args:
        dataset: Any object supporting ``len()`` and integer indexing.
        batch_size: Batch size used by both loaders.
        val_size: Fraction of samples assigned to validation, in [0, 1].
        shuffle: If True, shuffle the indices before splitting so the split is
            random; otherwise the validation subset is the first samples in
            dataset order.
        seed: Seed for the split shuffle. A private ``RandomState`` is used, so
            the global NumPy random state is not modified.
        num_workers: Number of worker processes for each DataLoader
            (defaults to 2, the previously hard-coded value).

    Returns:
        A ``(train_loader, val_loader)`` tuple.

    Raises:
        ValueError: If ``val_size`` is not within [0, 1].
    """
    if not 0 <= val_size <= 1:
        raise ValueError(f"val_size must be in [0, 1], got {val_size!r}")

    # Creating data indices for training and validation splits.
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    val_split = int(np.floor(val_size * dataset_size))

    if shuffle:
        # Use a private generator instead of np.random.seed(): reseeding the
        # global generator would silently alter every later np.random call in
        # the caller's program. RandomState(seed).shuffle yields the same
        # permutation the legacy global generator produces for the same seed.
        rng = np.random.RandomState(seed)
        rng.shuffle(indices)

    # The first val_split indices go to validation, all remaining ones to training.
    train_indices = indices[val_split:]
    val_indices = indices[:val_split]

    # Creating PT data samplers and loaders.
    train_sampler = SubsetRandomSampler(train_indices)
    val_sampler = SubsetRandomSampler(val_indices)

    # Initializing dataloaders.
    train_loader = DataLoader(dataset,
                              batch_size=batch_size,
                              sampler=train_sampler,
                              num_workers=num_workers)

    val_loader = DataLoader(dataset,
                            batch_size=batch_size,
                            sampler=val_sampler,
                            num_workers=num_workers)

    return train_loader, val_loader


class MNISTDataset(Dataset):
    """In-memory image dataset storing single-channel images as float tensors.

    Images are reshaped to ``(N, 1, H, W)``, converted to float and divided by
    255. Labels are kept exactly as supplied and returned unchanged.
    """

    def __init__(self, X, y, transform=None):
        """Build the dataset from raw images and labels.

        Args:
            X: Array-like (NumPy array, nested list or tensor) of images with
                shape ``(N, H, W)``. Pixel values are divided by 255, so this
                presumably expects raw 0-255 intensities — confirm with callers.
            y: Indexable sequence of labels, one per image.
            transform: Optional callable applied to each ``(1, H, W)`` image
                tensor in ``__getitem__``.

        Raises:
            ValueError: If the number of labels differs from the number of images.
        """
        self.transform = transform
        # torch.as_tensor avoids a redundant copy of NumPy input and the
        # UserWarning torch.tensor emits when X is already a tensor. The
        # .float() / 255 below always yields a new tensor, so self.data never
        # shares memory with the caller's X.
        X = torch.as_tensor(X)
        if len(y) != X.shape[0]:
            # Otherwise extra labels are silently ignored, or a missing label
            # surfaces as an IndexError partway through an epoch.
            raise ValueError(
                f"X has {X.shape[0]} images but y has {len(y)} labels"
            )
        # Insert a singleton channel axis: (N, H, W) -> (N, 1, H, W).
        self.data = torch.reshape(X, (X.shape[0], 1, X.shape[1], X.shape[2])).float() / 255
        self.labels = y

    def __len__(self):
        """Return the number of images."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``, applying ``transform`` if set."""
        image = self.data[index]

        if self.transform:
            image = self.transform(image)

        return image, self.labels[index]

