import math
from data.latent_dataset import LatentDataset
from torch.utils.data import DataLoader, random_split


def create_latent_dataloader(dataset, batch_size, shuffle, drop_last, **kwargs):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader``.

    Args:
        dataset: Any object ``DataLoader`` accepts as a dataset.
        batch_size: Number of samples per batch.
        shuffle: Whether ``DataLoader`` should reshuffle samples each epoch.
        drop_last: Whether to discard a trailing batch smaller than
            ``batch_size``.
        **kwargs: Further keyword arguments passed unchanged to ``DataLoader``
            (for example ``num_workers`` or ``pin_memory``).

    Returns:
        The constructed ``DataLoader``.
    """
    # The three batching options are named parameters of this function, so
    # they can never collide with anything inside ``kwargs``.
    batching = {"batch_size": batch_size, "shuffle": shuffle, "drop_last": drop_last}
    return DataLoader(dataset=dataset, **batching, **kwargs)


def create_latent_train_dataloader(dataset_path, batch_size, **kwargs):
    """Build a training loader over the ``LatentDataset`` at ``dataset_path``.

    Samples are shuffled, and any final batch smaller than ``batch_size`` is
    dropped, so every yielded batch has exactly ``batch_size`` samples.

    Args:
        dataset_path: Passed straight to ``LatentDataset`` (presumably a
            file or directory path -- TODO confirm against LatentDataset).
        batch_size: Number of samples per batch.
        **kwargs: Extra keyword arguments forwarded to ``DataLoader``.

    Returns:
        The configured ``DataLoader``.
    """
    train_set = LatentDataset(dataset_path)
    return create_latent_dataloader(
        dataset=train_set,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        **kwargs,
    )


def create_latent_train_valid_dataloaders(dataset_path, batch_size, perc_valid, **kwargs):
    """Randomly split one ``LatentDataset`` into training and validation loaders.

    Args:
        dataset_path: Passed straight to ``LatentDataset``.
        batch_size: Number of samples per batch, used by both loaders.
        perc_valid: Fraction of the dataset, in ``[0, 1]``, held out for
            validation. The validation split gets
            ``floor(len(dataset) * perc_valid)`` samples and the training
            split gets the rest.
        **kwargs: Extra keyword arguments forwarded to both DataLoader objects.

    Returns:
        A ``(train_loader, valid_loader)`` tuple. The training loader
        shuffles and drops a trailing incomplete batch. The validation loader
        shuffles and keeps every sample. A split that ends up empty is not
        shuffled, because there is nothing to shuffle.

    Raises:
        ValueError: If ``perc_valid`` is outside ``[0, 1]`` or is NaN.
    """
    # Use an explicit check rather than ``assert`` so validation still runs
    # under ``python -O``. NaN fails both comparisons and is rejected too.
    if not 0 <= perc_valid <= 1:
        raise ValueError(f"perc_valid must be in [0, 1], got {perc_valid!r}")

    ds = LatentDataset(dataset_path)
    valid_len = math.floor(len(ds) * perc_valid)
    train, valid = random_split(ds, [len(ds) - valid_len, valid_len])

    # torch's RandomSampler raises ValueError for a dataset of length 0. A
    # split is empty when perc_valid is 0 or 1 (or the dataset is tiny), so
    # only request shuffling for a split that actually contains samples.
    train_loader = create_latent_dataloader(dataset=train, batch_size=batch_size,
                                            shuffle=len(train) > 0, drop_last=True, **kwargs)
    valid_loader = create_latent_dataloader(dataset=valid, batch_size=batch_size,
                                            shuffle=len(valid) > 0, drop_last=False, **kwargs)
    return train_loader, valid_loader


def create_latent_test_dataloader(dataset_path, batch_size, **kwargs):
    """Build an evaluation loader over the ``LatentDataset`` at ``dataset_path``.

    Shuffling is disabled and the final partial batch is kept, so every
    sample is visited once per pass.

    Args:
        dataset_path: Passed straight to ``LatentDataset``.
        batch_size: Number of samples per batch.
        **kwargs: Extra keyword arguments forwarded to ``DataLoader``.

    Returns:
        The configured ``DataLoader``.
    """
    test_set = LatentDataset(dataset_path)
    return create_latent_dataloader(
        dataset=test_set,
        batch_size=batch_size,
        drop_last=False,
        shuffle=False,
        **kwargs,
    )
